mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-03 05:52:42 +00:00
Compare commits
4 Commits
cli-agent-
...
temp/pr-98
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
569fef22b4 | ||
|
|
517d1035e0 | ||
|
|
ad108dd573 | ||
|
|
2a5810a44a |
285
.github/workflows/deployment.yml
vendored
285
.github/workflows/deployment.yml
vendored
@@ -1509,105 +1509,232 @@ jobs:
|
||||
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
|
||||
$IMAGES
|
||||
|
||||
trivy-scan:
|
||||
trivy-scan-web:
|
||||
needs:
|
||||
- determine-builds
|
||||
- merge-web
|
||||
- merge-web-cloud
|
||||
- merge-backend
|
||||
- merge-model-server
|
||||
if: >-
|
||||
always() && !cancelled() &&
|
||||
(needs.merge-web.result == 'success' ||
|
||||
needs.merge-web-cloud.result == 'success' ||
|
||||
needs.merge-backend.result == 'success' ||
|
||||
needs.merge-model-server.result == 'success')
|
||||
if: needs.merge-web.result == 'success'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-trivy-scan-${{ matrix.component }}
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web
|
||||
- extras=ecr-cache
|
||||
permissions:
|
||||
security-events: write # needed for SARIF uploads
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- component: web
|
||||
registry-image: onyxdotapp/onyx-web-server
|
||||
- component: web-cloud
|
||||
registry-image: onyxdotapp/onyx-web-server-cloud
|
||||
- component: backend
|
||||
registry-image: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
trivyignore: backend/.trivyignore
|
||||
- component: model-server
|
||||
registry-image: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
fi
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
${SCAN_IMAGE}
|
||||
|
||||
trivy-scan-web-cloud:
|
||||
needs:
|
||||
- determine-builds
|
||||
- merge-web-cloud
|
||||
if: needs.merge-web-cloud.result == 'success'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-cloud-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
fi
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
${SCAN_IMAGE}
|
||||
|
||||
trivy-scan-backend:
|
||||
needs:
|
||||
- determine-builds
|
||||
- merge-backend
|
||||
if: needs.merge-backend.result == 'success'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-trivy-scan-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
- name: Check if this scan should run
|
||||
id: should-run
|
||||
run: |
|
||||
case "$COMPONENT" in
|
||||
web) RESULT="$MERGE_WEB" ;;
|
||||
web-cloud) RESULT="$MERGE_WEB_CLOUD" ;;
|
||||
backend) RESULT="$MERGE_BACKEND" ;;
|
||||
model-server) RESULT="$MERGE_MODEL_SERVER" ;;
|
||||
esac
|
||||
if [ "$RESULT" == "success" ]; then
|
||||
echo "run=true" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "run=false" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
env:
|
||||
COMPONENT: ${{ matrix.component }}
|
||||
MERGE_WEB: ${{ needs.merge-web.result }}
|
||||
MERGE_WEB_CLOUD: ${{ needs.merge-web-cloud.result }}
|
||||
MERGE_BACKEND: ${{ needs.merge-backend.result }}
|
||||
MERGE_MODEL_SERVER: ${{ needs.merge-model-server.result }}
|
||||
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
if: steps.should-run.outputs.run == 'true'
|
||||
|
||||
- name: Checkout
|
||||
if: steps.should-run.outputs.run == 'true' && matrix.trivyignore != ''
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Determine scan image
|
||||
if: steps.should-run.outputs.run == 'true'
|
||||
id: scan-image
|
||||
run: |
|
||||
if [ "$IS_TEST_RUN" == "true" ]; then
|
||||
echo "image=${RUNS_ON_ECR_CACHE}:${TAG_PREFIX}-${SANITIZED_TAG}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "image=docker.io/${REGISTRY_IMAGE}:${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
TAG_PREFIX: ${{ matrix.component }}
|
||||
SANITIZED_TAG: ${{ needs.determine-builds.outputs.sanitized-tag }}
|
||||
REGISTRY_IMAGE: ${{ matrix.registry-image }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
if: steps.should-run.outputs.run == 'true'
|
||||
uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # ratchet:aquasecurity/trivy-action@v0.35.0
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
image-ref: ${{ steps.scan-image.outputs.image }}
|
||||
severity: CRITICAL,HIGH
|
||||
format: "sarif"
|
||||
output: "trivy-results.sarif"
|
||||
trivyignores: ${{ matrix.trivyignore }}
|
||||
env:
|
||||
TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
fi
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
--ignorefile /tmp/.trivyignore \
|
||||
${SCAN_IMAGE}
|
||||
|
||||
- name: Upload Trivy scan results to GitHub Security tab
|
||||
if: steps.should-run.outputs.run == 'true'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab
|
||||
trivy-scan-model-server:
|
||||
needs:
|
||||
- determine-builds
|
||||
- merge-model-server
|
||||
if: needs.merge-model-server.result == 'success'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
|
||||
with:
|
||||
sarif_file: "trivy-results.sarif"
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 30
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
fi
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
--timeout 20m \
|
||||
--severity CRITICAL,HIGH \
|
||||
${SCAN_IMAGE}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs:
|
||||
|
||||
174
.github/workflows/pr-python-connector-tests.yml
vendored
174
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -22,40 +22,132 @@ on:
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
permissions:
|
||||
id-token: write # Required for OIDC-based AWS credential exchange
|
||||
contents: read
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
DISABLE_TELEMETRY: "true"
|
||||
# AWS
|
||||
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Cloudflare R2
|
||||
R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS: ${{ vars.R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS }}
|
||||
R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Google Cloud Storage
|
||||
GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
|
||||
GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
|
||||
|
||||
# Confluence
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# Jira
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
|
||||
# Gong
|
||||
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
|
||||
GONG_ACCESS_KEY_SECRET: ${{ secrets.GONG_ACCESS_KEY_SECRET }}
|
||||
|
||||
# Google
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
# Zendesk
|
||||
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
|
||||
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
|
||||
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
|
||||
|
||||
# Salesforce
|
||||
SF_USERNAME: ${{ vars.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
|
||||
# Hubspot
|
||||
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
|
||||
|
||||
# IMAP
|
||||
IMAP_HOST: ${{ vars.IMAP_HOST }}
|
||||
IMAP_USERNAME: ${{ vars.IMAP_USERNAME }}
|
||||
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
|
||||
IMAP_MAILBOXES: ${{ vars.IMAP_MAILBOXES }}
|
||||
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ vars.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ vars.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ vars.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
|
||||
# Sharepoint
|
||||
SHAREPOINT_CLIENT_ID: ${{ vars.SHAREPOINT_CLIENT_ID }}
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
# Github
|
||||
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
|
||||
|
||||
# Gitlab
|
||||
GITLAB_ACCESS_TOKEN: ${{ secrets.GITLAB_ACCESS_TOKEN }}
|
||||
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
|
||||
# Highspot
|
||||
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
|
||||
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
|
||||
|
||||
# Slack
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
# Discord
|
||||
DISCORD_CONNECTOR_BOT_TOKEN: ${{ secrets.DISCORD_CONNECTOR_BOT_TOKEN }}
|
||||
|
||||
# Teams
|
||||
TEAMS_APPLICATION_ID: ${{ secrets.TEAMS_APPLICATION_ID }}
|
||||
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
|
||||
TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}
|
||||
|
||||
# Bitbucket
|
||||
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
|
||||
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
|
||||
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
|
||||
BITBUCKET_EMAIL: ${{ vars.BITBUCKET_EMAIL }}
|
||||
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}
|
||||
|
||||
# Fireflies
|
||||
FIREFLIES_API_KEY: ${{ secrets.FIREFLIES_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=8cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-connectors-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-connectors-check", "extras=s3-cache"]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
DISABLE_TELEMETRY: "true"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -96,66 +188,6 @@ jobs:
|
||||
- 'backend/onyx/file_processing/**'
|
||||
- 'uv.lock'
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v4
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get connector test secrets from AWS Secrets Manager
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2
|
||||
with:
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/aws-access-key-id
|
||||
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/aws-secret-access-key
|
||||
R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/r2-access-key-id
|
||||
R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/r2-secret-access-key
|
||||
GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/gcs-access-key-id
|
||||
GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/gcs-secret-access-key
|
||||
CONFLUENCE_ACCESS_TOKEN, test/confluence-access-token
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED, test/confluence-access-token-scoped
|
||||
JIRA_BASE_URL, test/jira-base-url
|
||||
JIRA_USER_EMAIL, test/jira-user-email
|
||||
JIRA_API_TOKEN, test/jira-api-token
|
||||
JIRA_API_TOKEN_SCOPED, test/jira-api-token-scoped
|
||||
GONG_ACCESS_KEY, test/gong-access-key
|
||||
GONG_ACCESS_KEY_SECRET, test/gong-access-key-secret
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR, test/google-drive-service-account-json
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1, test/google-drive-oauth-creds-test-user-1
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR, test/google-drive-oauth-creds
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR, test/google-gmail-service-account-json
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR, test/google-gmail-oauth-creds
|
||||
SLAB_BOT_TOKEN, test/slab-bot-token
|
||||
ZENDESK_SUBDOMAIN, test/zendesk-subdomain
|
||||
ZENDESK_EMAIL, test/zendesk-email
|
||||
ZENDESK_TOKEN, test/zendesk-token
|
||||
SF_PASSWORD, test/sf-password
|
||||
SF_SECURITY_TOKEN, test/sf-security-token
|
||||
HUBSPOT_ACCESS_TOKEN, test/hubspot-access-token
|
||||
IMAP_PASSWORD, test/imap-password
|
||||
AIRTABLE_ACCESS_TOKEN, test/airtable-access-token
|
||||
SHAREPOINT_CLIENT_SECRET, test/sharepoint-client-secret
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID, test/perm-sync-sharepoint-client-id
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY, test/perm-sync-sharepoint-private-key
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD, test/perm-sync-sharepoint-cert-password
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID, test/perm-sync-sharepoint-directory-id
|
||||
ACCESS_TOKEN_GITHUB, test/github-access-token
|
||||
GITLAB_ACCESS_TOKEN, test/gitlab-access-token
|
||||
GITBOOK_SPACE_ID, test/gitbook-space-id
|
||||
GITBOOK_API_KEY, test/gitbook-api-key
|
||||
NOTION_INTEGRATION_TOKEN, test/notion-integration-token
|
||||
HIGHSPOT_KEY, test/highspot-key
|
||||
HIGHSPOT_SECRET, test/highspot-secret
|
||||
SLACK_BOT_TOKEN, test/slack-bot-token
|
||||
DISCORD_CONNECTOR_BOT_TOKEN, test/discord-bot-token
|
||||
TEAMS_APPLICATION_ID, test/teams-application-id
|
||||
TEAMS_DIRECTORY_ID, test/teams-directory-id
|
||||
TEAMS_SECRET, test/teams-secret
|
||||
BITBUCKET_WORKSPACE, test/bitbucket-workspace
|
||||
BITBUCKET_API_TOKEN, test/bitbucket-api-token
|
||||
FIREFLIES_API_KEY, test/fireflies-api-key
|
||||
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
|
||||
1
.github/workflows/preview.yml
vendored
1
.github/workflows/preview.yml
vendored
@@ -15,6 +15,7 @@ permissions:
|
||||
jobs:
|
||||
Deploy-Preview:
|
||||
runs-on: ubuntu-latest
|
||||
environment: ci-protected
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
|
||||
@@ -6,7 +6,7 @@ Use explicit type annotations for variables to enhance code clarity, especially
|
||||
|
||||
## Best Practices
|
||||
|
||||
Use the "Engineering Best Practices" section of `CONTRIBUTING.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
|
||||
Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly.
|
||||
|
||||
## TODOs
|
||||
|
||||
@@ -27,7 +27,6 @@ Code changes must consider both multi-tenant and single-tenant deployments. In m
|
||||
## Nginx Routing — New Backend Routes
|
||||
|
||||
Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs:
|
||||
|
||||
- `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s)
|
||||
- `deployment/data/nginx/app.conf.template` (docker-compose dev)
|
||||
- `deployment/data/nginx/app.conf.template.prod` (docker-compose prod)
|
||||
@@ -38,7 +37,3 @@ Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.
|
||||
## Full vs Lite Deployments
|
||||
|
||||
Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments.
|
||||
|
||||
## SWR Cache Keys — Always Use SWR_KEYS Registry
|
||||
|
||||
All `useSWR()` calls and `mutate()` calls in the frontend must reference the centralized `SWR_KEYS` registry in `web/src/lib/swr-keys.ts` instead of inline endpoint strings or local string constants. Never write `useSWR("/api/some/endpoint", ...)` or `mutate("/api/some/endpoint")` — always use the corresponding `SWR_KEYS.someEndpoint` constant. If the endpoint does not yet exist in the registry, add it there first. This applies to all variants of an endpoint (e.g. query-string variants like `?get_editable=true` must also be registered as their own key).
|
||||
|
||||
@@ -357,5 +357,5 @@ raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.respon
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
to the codebase can be found in the "Engineering Best Practices" section of
|
||||
`CONTRIBUTING.md`. Understand its contents and follow them.
|
||||
to the codebase can be found at `contributing_guides/best_practices.md`.
|
||||
Understand its contents and follow them.
|
||||
|
||||
481
CONTRIBUTING.md
481
CONTRIBUTING.md
@@ -1,487 +1,32 @@
|
||||
# Contributing to Onyx
|
||||
|
||||
Hey there! We are so excited that you're interested in Onyx.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Contribution Opportunities](#contribution-opportunities)
|
||||
- [Contribution Process](#contribution-process)
|
||||
- [Development Setup](#development-setup)
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Backend: Python Requirements](#backend-python-requirements)
|
||||
- [Frontend: Node Dependencies](#frontend-node-dependencies)
|
||||
- [Formatting and Linting](#formatting-and-linting)
|
||||
- [Running the Application](#running-the-application)
|
||||
- [VSCode Debugger (Recommended)](#vscode-debugger-recommended)
|
||||
- [Manually Running for Development](#manually-running-for-development)
|
||||
- [Running in Docker](#running-in-docker)
|
||||
- [macOS-Specific Notes](#macos-specific-notes)
|
||||
- [Engineering Best Practices](#engineering-best-practices)
|
||||
- [Principles and Collaboration](#principles-and-collaboration)
|
||||
- [Style and Maintainability](#style-and-maintainability)
|
||||
- [Performance and Correctness](#performance-and-correctness)
|
||||
- [Repository Conventions](#repository-conventions)
|
||||
- [Release Process](#release-process)
|
||||
- [Getting Help](#getting-help)
|
||||
- [Enterprise Edition Contributions](#enterprise-edition-contributions)
|
||||
|
||||
---
|
||||
|
||||
## Contribution Opportunities
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.
|
||||
|
||||
If you have your own feature that you would like to build, please create an issue and community members can provide feedback and upvote if they feel a common need.
|
||||
If you have your own feature that you would like to build please create an issue and community members can provide feedback and
|
||||
thumb it up if they feel a common need.
|
||||
|
||||
---
|
||||
|
||||
## Contribution Process
|
||||
## Contributing Code
|
||||
Please reference the documents in contributing_guides folder to ensure that the code base is kept to a high standard.
|
||||
1. dev_setup.md (start here): gives you a guide to setting up a local development environment.
|
||||
2. contribution_process.md: how to ensure you are building valuable features that will get reviewed and merged.
|
||||
3. best_practices.md: before asking for reviews, ensure your changes meet the repo code quality standards.
|
||||
|
||||
To contribute, please follow the
|
||||
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
|
||||
|
||||
### 1. Get the feature or enhancement approved
|
||||
|
||||
Create a GitHub issue and see if there are upvotes. If you feel the feature is sufficiently value-additive and you would like approval to contribute it to the repo, tag [Yuhong](https://github.com/yuhongsun96) to review.
|
||||
|
||||
If you do not get a response within a week, feel free to email yuhong@onyx.app and include the issue in the message.
|
||||
|
||||
Not all small features and enhancements will be accepted as there is a balance between feature richness and bloat. We strive to provide the best user experience possible so we have to be intentional about what we include in the app.
|
||||
|
||||
### 2. Get the design approved
|
||||
|
||||
The Onyx team will either provide a design doc and PRD for the feature or request one from you, the contributor. The scope and detail of the design will depend on the individual feature.
|
||||
|
||||
### 3. IP attribution for EE contributions
|
||||
|
||||
If you are contributing features to Onyx Enterprise Edition, you are required to sign the [IP Assignment Agreement](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.md).
|
||||
|
||||
### 4. Review and testing
|
||||
|
||||
Your features must pass all tests and all comments must be addressed prior to merging.
|
||||
|
||||
### Implicit agreements
|
||||
|
||||
If we approve an issue, we are promising you the following:
|
||||
- Your work will receive timely attention and we will put aside other important items to ensure you are not blocked.
|
||||
- You will receive necessary coaching on eng quality, system design, etc. to ensure the feature is completed well.
|
||||
- The Onyx team will pull resources and bandwidth from design, PM, and engineering to ensure that you have all the resources to build the feature to the quality required for merging.
|
||||
|
||||
Because this is a large investment from our team, we ask that you:
|
||||
- Thoroughly read all the requirements of the design docs, engineering best practices, and try to minimize overhead for the Onyx team.
|
||||
- Complete the feature in a timely manner to reduce context switching and an ongoing resource pull from the Onyx team.
|
||||
|
||||
---
|
||||
|
||||
## Development Setup
|
||||
|
||||
Onyx being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [OpenSearch](https://opensearch.org/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [MinIO](https://min.io/) (File Store)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software.
|
||||
> We believe this combination is easier for development purposes. If you prefer to use pre-built container images, see [Running in Docker](#running-in-docker) below.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- **Python 3.11** — If using a lower version, modifications will have to be made to the code. Higher versions may have library compatibility issues.
|
||||
- **Docker** — Required for running external services (Postgres, OpenSearch, Redis, MinIO).
|
||||
- **Node.js v22** — We recommend using [nvm](https://github.com/nvm-sh/nvm) to manage Node installations.
|
||||
|
||||
### Backend: Python Requirements
|
||||
|
||||
We use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
|
||||
|
||||
```bash
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
```
|
||||
|
||||
If using PowerShell, the command slightly differs:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Install the required Python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
```bash
|
||||
uv run playwright install
|
||||
```
|
||||
|
||||
### Frontend: Node Dependencies
|
||||
|
||||
```bash
|
||||
nvm install 22 && nvm use 22
|
||||
node -v # verify your active version
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
#### Backend
|
||||
|
||||
Set up pre-commit hooks (black / reorder-python-imports):
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
We also use `mypy` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the mypy checks manually:
|
||||
|
||||
```bash
|
||||
uv run mypy . # from onyx/backend
|
||||
```
|
||||
|
||||
#### Frontend
|
||||
|
||||
We use `prettier` for formatting. The desired version will be installed via `npm i` from the `onyx/web` directory. To run the formatter:
|
||||
|
||||
```bash
|
||||
npx prettier --write . # from onyx/web
|
||||
```
|
||||
|
||||
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail. Re-stage your changes and commit again.
|
||||
|
||||
---
|
||||
|
||||
## Running the Application
|
||||
|
||||
### VSCode Debugger (Recommended)
|
||||
|
||||
We highly recommend using VSCode's debugger for development.
|
||||
|
||||
#### Initial Setup
|
||||
|
||||
1. Copy `.vscode/env_template.txt` to `.vscode/.env`
|
||||
2. Fill in the necessary environment variables in `.vscode/.env`
|
||||
|
||||
#### Using the Debugger
|
||||
|
||||
Before starting, make sure the Docker Daemon is running.
|
||||
|
||||
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
|
||||
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
|
||||
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
|
||||
4. Navigate to http://localhost:3000 in your browser to start using the app
|
||||
5. Set breakpoints by clicking to the left of line numbers to help debug while the app is running
|
||||
6. Use the debug toolbar to step through code, inspect variables, etc.
|
||||
|
||||
> **Note:** "Clear and Restart External Volumes and Containers" will reset your Postgres and OpenSearch (relational-db and index). Only run this if you are okay with wiping your data.
|
||||
|
||||
**Features:**
|
||||
- Hot reload is enabled for the web server and API servers
|
||||
- Python debugging is configured with debugpy
|
||||
- Environment variables are loaded from `.vscode/.env`
|
||||
- Console output is organized in the integrated terminal with labeled tabs
|
||||
|
||||
### Manually Running for Development
|
||||
|
||||
#### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
Navigate to `onyx/deployment/docker_compose`, then start up Postgres/OpenSearch/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to OpenSearch, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
#### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Next, start the model server which runs the local NLP models. Navigate to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Onyx, you will need to run the DB migrations for Postgres. After the first time, this is no longer required unless the DB models change.
|
||||
|
||||
Navigate to `onyx/backend` and with the venv active, run:
|
||||
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
Next, start the task queue which orchestrates the background jobs. Still in `onyx/backend`, run:
|
||||
|
||||
```bash
|
||||
python ./scripts/dev_run_background_jobs.py
|
||||
```
|
||||
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='basic'
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
> **Note:** If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
#### Wrapping up
|
||||
|
||||
You should now have 4 servers running:
|
||||
|
||||
- Web server
|
||||
- Backend API
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit http://localhost:3000 in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
|
||||
|
||||
You've successfully set up a local Onyx instance!
|
||||
|
||||
### Running in Docker
|
||||
|
||||
You can run the full Onyx application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `onyx/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to http://localhost:3000 to use Onyx.
|
||||
|
||||
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes:
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## macOS-Specific Notes
|
||||
|
||||
### Setting up Python
|
||||
|
||||
Ensure [Homebrew](https://brew.sh/) is already set up, then install Python 3.11:
|
||||
|
||||
```bash
|
||||
brew install python@3.11
|
||||
```
|
||||
|
||||
Add Python 3.11 to your path by adding the following line to `~/.zshrc`:
|
||||
|
||||
```
|
||||
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
> **Note:** You will need to open a new terminal for the path change above to take effect.
|
||||
|
||||
### Setting up Docker
|
||||
|
||||
On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and ensure it is running before continuing with the docker commands.
|
||||
|
||||
### Formatting and Linting
|
||||
|
||||
macOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly. After installing pre-commit, run the following command:
|
||||
|
||||
```bash
|
||||
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Engineering Best Practices
|
||||
|
||||
> These are also what we adhere to as a team internally, we love to build in the open and to uplevel our community and each other through being transparent.
|
||||
|
||||
### Principles and Collaboration
|
||||
|
||||
- **Use 1-way vs 2-way doors.** For 2-way doors, move faster and iterate. For 1-way doors, be more deliberate.
|
||||
- **Consistency > being "right."** Prefer consistent patterns across the codebase. If something is truly bad, fix it everywhere.
|
||||
- **Fix what you touch (selectively).**
|
||||
- Don't feel obligated to fix every best-practice issue you notice.
|
||||
- Don't introduce new bad practices.
|
||||
- If your change touches code that violates best practices, fix it as part of the change.
|
||||
- **Don't tack features on.** When adding functionality, restructure logically as needed to avoid muddying interfaces and accumulating tech debt.
|
||||
|
||||
### Style and Maintainability
|
||||
|
||||
#### Comments and readability
|
||||
Add clear comments:
|
||||
- At logical boundaries (e.g., interfaces) so the reader doesn't need to dig 10 layers deeper.
|
||||
- Wherever assumptions are made or something non-obvious/unexpected is done.
|
||||
- For complicated flows/functions.
|
||||
- Wherever it saves time (e.g., nontrivial regex patterns).
|
||||
|
||||
#### Errors and exceptions
|
||||
- **Fail loudly** rather than silently skipping work.
|
||||
- Example: raise and let exceptions propagate instead of silently dropping a document.
|
||||
- **Don't overuse `try/except`.**
|
||||
- Put `try/except` at the correct logical level.
|
||||
- Do not mask exceptions unless it is clearly appropriate.
|
||||
|
||||
#### Typing
|
||||
- Everything should be **as strictly typed as possible**.
|
||||
- Use `cast` for annoying/loose-typed interfaces (e.g., results of `run_functions_tuples_in_parallel`).
|
||||
- Only `cast` when the type checker sees `Any` or types are too loose.
|
||||
- Prefer types that are easy to read.
|
||||
- Avoid dense types like `dict[tuple[str, str], list[list[float]]]`.
|
||||
- Prefer domain models, e.g.:
|
||||
- `EmbeddingModel(provider_name, model_name)` as a Pydantic model
|
||||
- `dict[EmbeddingModel, list[EmbeddingVector]]`
|
||||
|
||||
#### State, objects, and boundaries
|
||||
- Keep **clear logical boundaries** for state containers and objects.
|
||||
- A **config** object should never contain things like a `db_session`.
|
||||
- Avoid state containers that are overly nested, or huge + flat (use judgment).
|
||||
- Prefer **composition and functional style** over inheritance/OOP.
|
||||
- Prefer **no mutation** unless there's a strong reason.
|
||||
- State objects should be **intentional and explicit**, ideally nonmutating.
|
||||
- Use interfaces/objects to create clear separation of responsibility.
|
||||
- Prefer simplicity when there's no clear gain.
|
||||
- Avoid overcomplicated mechanisms like semaphores.
|
||||
- Prefer **hash maps (dicts)** over tree structures unless there's a strong reason.
|
||||
|
||||
#### Naming
|
||||
- Name variables carefully and intentionally.
|
||||
- Prefer long, explicit names when undecided.
|
||||
- Avoid single-character variables except for small, self-contained utilities (or not at all).
|
||||
- Keep the same object/name consistent through the call stack and within functions when reasonable.
|
||||
- Good: `for token in tokens:`
|
||||
- Bad: `for msg in tokens:` (if iterating tokens)
|
||||
- Function names should bias toward **long + descriptive** for codebase search.
|
||||
- IntelliSense can miss call sites; search works best with unique names.
|
||||
|
||||
#### Correctness by construction
|
||||
- Prefer self-contained correctness — don't rely on callers to "use it right" if you can make misuse hard.
|
||||
- Avoid redundancies: if a function takes an arg, it shouldn't also take a state object that contains that same arg.
|
||||
- No dead code (unless there's a very good reason).
|
||||
- No commented-out code in main or feature branches (unless there's a very good reason).
|
||||
- No duplicate logic:
|
||||
- Don't copy/paste into branches when shared logic can live above the conditional.
|
||||
- If you're afraid to touch the original, you don't understand it well enough.
|
||||
- LLMs often create subtle duplicate logic — review carefully and remove it.
|
||||
- Avoid "nearly identical" objects that confuse when to use which.
|
||||
- Avoid extremely long functions with chained logic:
|
||||
- Encapsulate steps into helpers for readability, even if not reused.
|
||||
- "Pythonic" multi-step expressions are OK in moderation; don't trade clarity for cleverness.
|
||||
|
||||
### Performance and Correctness
|
||||
|
||||
- Avoid holding resources for extended periods (DB sessions, locks/semaphores).
|
||||
- Validate objects on creation and right before use.
|
||||
- Connector code (data to Onyx documents):
|
||||
- Any in-memory structure that can grow without bound based on input must be periodically size-checked.
|
||||
- If a connector is OOMing (often shows up as "missing celery tasks"), this is a top thing to check retroactively.
|
||||
- Async and event loops:
|
||||
- Never introduce new async/event loop Python code, and try to make existing async code synchronous when possible if it makes sense.
|
||||
- Writing async code without 100% understanding the code and having a concrete reason to do so is likely to introduce bugs and not add any meaningful performance gains.
|
||||
|
||||
### Repository Conventions
|
||||
|
||||
#### Where code lives
|
||||
- Pydantic + data models: `models.py` files.
|
||||
- DB interface functions (excluding lazy loading): `db/` directory.
|
||||
- LLM prompts: `prompts/` directory, roughly mirroring the code layout that uses them.
|
||||
- API routes: `server/` directory.
|
||||
|
||||
#### Pydantic and modeling
|
||||
- Prefer **Pydantic** over dataclasses.
|
||||
- If absolutely required, use `allow_arbitrary_types`.
|
||||
|
||||
#### Data conventions
|
||||
- Prefer explicit `None` over sentinel empty strings (usually; depends on intent).
|
||||
- Prefer explicit identifiers: use string enums instead of integer codes.
|
||||
- Avoid magic numbers (co-location is good when necessary). **Always avoid magic strings.**
|
||||
|
||||
#### Logging
|
||||
- Log messages where they are created.
|
||||
- Don't propagate log messages around just to log them elsewhere.
|
||||
|
||||
#### Encapsulation
|
||||
- Don't use private attributes/methods/properties from other classes/modules.
|
||||
- "Private" is private — respect that boundary.
|
||||
|
||||
#### SQLAlchemy guidance
|
||||
- Lazy loading is often bad at scale, especially across multiple list relationships.
|
||||
- Be careful when accessing SQLAlchemy object attributes:
|
||||
- It can help avoid redundant DB queries,
|
||||
- but it can also fail if accessed outside an active session,
|
||||
- and lazy loading can add hidden DB dependencies to otherwise "simple" functions.
|
||||
- Reference: https://www.reddit.com/r/SQLAlchemy/comments/138f248/joinedload_vs_selectinload/
|
||||
|
||||
#### Trunk-based development and feature flags
|
||||
- **PRs should contain no more than 500 lines of real change.**
|
||||
- **Merge to main frequently.** Avoid long-lived feature branches — they create merge conflicts and integration pain.
|
||||
- **Use feature flags for incremental rollout.**
|
||||
- Large features should be merged in small, shippable increments behind a flag.
|
||||
- This allows continuous integration without exposing incomplete functionality.
|
||||
- **Keep flags short-lived.** Once a feature is fully rolled out, remove the flag and dead code paths promptly.
|
||||
- **Flag at the right level.** Prefer flagging at API/UI entry points rather than deep in business logic.
|
||||
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.
|
||||
|
||||
#### Miscellaneous
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username of the owner of that TODO, or an issue number for an issue referencing that piece of work.
|
||||
- Avoid module-level logic that runs on import, which leads to import-time side effects. Essentially every piece of meaningful logic should exist within some function that has to be explicitly invoked. Acceptable exceptions may include loading environment variables or setting up loggers.
|
||||
- If you find yourself needing something like this, you may want that logic to exist in a file dedicated for manual execution (contains `if __name__ == "__main__":`) which should not be imported by anything else.
|
||||
- Do not conflate Python scripts you intend to run from the command line (contains `if __name__ == "__main__":`) with modules you intend to import from elsewhere. If for some unlikely reason they have to be the same file, any logic specific to executing the file (including imports) should be contained in the `if __name__ == "__main__":` block.
|
||||
- Generally these executable files exist in `backend/scripts/`.
|
||||
|
||||
---
|
||||
|
||||
## Release Process
|
||||
|
||||
Onyx loosely follows the SemVer versioning standard.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=onyx%2F).
|
||||
|
||||
---
|
||||
|
||||
## Getting Help
|
||||
|
||||
## Getting Help 🙋
|
||||
We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
See you there!
|
||||
|
||||
---
|
||||
|
||||
## Enterprise Edition Contributions
|
||||
|
||||
If you are contributing features to Onyx Enterprise Edition (code under any `ee/` directory), you are required to sign the [IP Assignment Agreement](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.md) ([PDF version](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.pdf)).
|
||||
## Release Process
|
||||
Onyx loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
You can see the containers [here](https://hub.docker.com/search?q=onyx%2F).
|
||||
|
||||
102
README.md
102
README.md
@@ -4,6 +4,8 @@
|
||||
<a href="https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true" /></a>
|
||||
</h2>
|
||||
|
||||
<p align="center">Open Source AI Platform</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord" />
|
||||
@@ -25,94 +27,82 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
# Onyx - The Open Source AI Platform
|
||||
|
||||
**[Onyx](https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)** is the application layer for LLMs - bringing a feature-rich interface that can be easily hosted by anyone.
|
||||
Onyx enables LLMs through advanced capabilities like RAG, web search, code execution, file creation, deep research and more.
|
||||
**[Onyx](https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)** is a feature-rich, self-hostable Chat UI that works with any LLM. It is easy to deploy and can run in a completely airgapped environment.
|
||||
|
||||
Connect your applications with over 50+ indexing based connectors provided out of the box or via MCP.
|
||||
Onyx comes loaded with advanced features like Agents, Web Search, RAG, MCP, Deep Research, Connectors to 40+ knowledge sources, and more.
|
||||
|
||||
> [!TIP]
|
||||
> Deploy with a single command:
|
||||
> Run Onyx with one command (or see deployment section below):
|
||||
> ```
|
||||
> curl -fsSL https://onyx.app/install_onyx.sh | bash
|
||||
> ```
|
||||
|
||||

|
||||
****
|
||||
|
||||

|
||||
|
||||
|
||||
---
|
||||
|
||||
## ⭐ Features
|
||||
|
||||
- **🔍 Agentic RAG:** Get best in class search and answer quality based on hybrid index + AI Agents for information retrieval
|
||||
- Benchmark to release soon!
|
||||
- **🔬 Deep Research:** Get in depth reports with a multi-step research flow.
|
||||
- Top of [leaderboard](https://github.com/onyx-dot-app/onyx_deep_research_bench) as of Feb 2026.
|
||||
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge, and actions.
|
||||
- **🌍 Web Search:** Browse the web to get up to date information.
|
||||
- Supports Serper, Google PSE, Brave, SearXNG, and others.
|
||||
- Comes with an in house web crawler and support for Firecrawl/Exa.
|
||||
- **📄 Artifacts:** Generate documents, graphics, and other downloadable artifacts.
|
||||
- **▶️ Actions & MCP:** Let Onyx agents interact with external applications, comes with flexible Auth options.
|
||||
- **💻 Code Execution:** Execute code in a sandbox to analyze data, render graphs, or modify files.
|
||||
- **🎙️ Voice Mode:** Chat with Onyx via text-to-speech and speech-to-text.
|
||||
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge and actions.
|
||||
- **🌍 Web Search:** Browse the web with Google PSE, Exa, and Serper as well as an in-house scraper or Firecrawl.
|
||||
- **🔍 RAG:** Best in class hybrid-search + knowledge graph for uploaded files and ingested documents from connectors.
|
||||
- **🔄 Connectors:** Pull knowledge, metadata, and access information from over 40 applications.
|
||||
- **🔬 Deep Research:** Get in depth answers with an agentic multi-step search.
|
||||
- **▶️ Actions & MCP:** Give AI Agents the ability to interact with external systems.
|
||||
- **💻 Code Interpreter:** Execute code to analyze data, render graphs and create files.
|
||||
- **🎨 Image Generation:** Generate images based on user prompts.
|
||||
- **👥 Collaboration:** Chat sharing, feedback gathering, user management, usage analytics, and more.
|
||||
|
||||
Onyx supports all major LLM providers, both self-hosted (like Ollama, LiteLLM, vLLM, etc.) and proprietary (like Anthropic, OpenAI, Gemini, etc.).
|
||||
Onyx works with all LLMs (like OpenAI, Anthropic, Gemini, etc.) and self-hosted LLMs (like Ollama, vLLM, etc.)
|
||||
|
||||
To learn more - check out our [docs](https://docs.onyx.app/welcome?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)!
|
||||
To learn more about the features, check out our [documentation](https://docs.onyx.app/welcome?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)!
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Deployment Modes
|
||||
|
||||
> Onyx supports deployments in Docker, Kubernetes, Helm/Terraform and provides guides for major cloud providers.
|
||||
> Detailed deployment guides found [here](https://docs.onyx.app/deployment/overview).
|
||||
## 🚀 Deployment
|
||||
Onyx supports deployments in Docker, Kubernetes, Terraform, along with guides for major cloud providers.
|
||||
|
||||
Onyx supports two separate deployment options: standard and lite.
|
||||
|
||||
#### Onyx Lite
|
||||
|
||||
The Lite mode can be thought of as a lightweight Chat UI. It requires less resources (under 1GB memory) and runs a less complex stack.
|
||||
It is great for users who want to test out Onyx quickly or for teams who are only interested in the Chat UI and Agents functionalities.
|
||||
|
||||
#### Standard Onyx
|
||||
|
||||
The complete feature set of Onyx which is recommended for serious users and larger teams. Additional components not included in Lite mode:
|
||||
- Vector + Keyword index for RAG.
|
||||
- Background containers to run job queues and workers for syncing knowledge from connectors.
|
||||
- AI model inference servers to run deep learning models used during indexing and inference.
|
||||
- Performance optimizations for large scale use via in memory cache (Redis) and blob store (MinIO).
|
||||
See guides below:
|
||||
- [Docker](https://docs.onyx.app/deployment/local/docker?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) or [Quickstart](https://docs.onyx.app/deployment/getting_started/quickstart?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) (best for most users)
|
||||
- [Kubernetes](https://docs.onyx.app/deployment/local/kubernetes?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) (best for large teams)
|
||||
- [Terraform](https://docs.onyx.app/deployment/local/terraform?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme) (best for teams already using Terraform)
|
||||
- Cloud specific guides (best if specifically using [AWS EKS](https://docs.onyx.app/deployment/cloud/aws/eks?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme), [Azure VMs](https://docs.onyx.app/deployment/cloud/azure?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme), etc.)
|
||||
|
||||
> [!TIP]
|
||||
> **To try Onyx for free without deploying, visit [Onyx Cloud](https://cloud.onyx.app/signup?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)**.
|
||||
> **To try Onyx for free without deploying, check out [Onyx Cloud](https://cloud.onyx.app/signup?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)**.
|
||||
|
||||
---
|
||||
|
||||
## 🏢 Onyx for Enterprise
|
||||
|
||||
Onyx is built for teams of all sizes, from individual users to the largest global enterprises:
|
||||
- 👥 Collaboration: Share chats and agents with other members of your organization.
|
||||
- 🔐 Single Sign On: SSO via Google OAuth, OIDC, or SAML. Group syncing and user provisioning via SCIM.
|
||||
- 🛡️ Role Based Access Control: RBAC for sensitive resources like access to agents, actions, etc.
|
||||
- 📊 Analytics: Usage graphs broken down by teams, LLMs, or agents.
|
||||
- 🕵️ Query History: Audit usage to ensure safe adoption of AI in your organization.
|
||||
- 💻 Custom code: Run custom code to remove PII, reject sensitive queries, or to run custom analysis.
|
||||
- 🎨 Whitelabeling: Customize the look and feel of Onyx with custom naming, icons, banners, and more.
|
||||
## 🔍 Other Notable Benefits
|
||||
Onyx is built for teams of all sizes, from individual users to the largest global enterprises.
|
||||
|
||||
- **Enterprise Search**: far more than simple RAG, Onyx has custom indexing and retrieval that remains performant and accurate for scales of up to tens of millions of documents.
|
||||
- **Security**: SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- **Management UI**: different user roles such as basic, curator, and admin.
|
||||
- **Document Permissioning**: mirrors user access from external apps for RAG use cases.
|
||||
|
||||
|
||||
|
||||
## 🚧 Roadmap
|
||||
To see ongoing and upcoming projects, check out our [roadmap](https://github.com/orgs/onyx-dot-app/projects/2)!
|
||||
|
||||
|
||||
|
||||
## 📚 Licensing
|
||||
|
||||
There are two editions of Onyx:
|
||||
|
||||
- Onyx Community Edition (CE) is available freely under the MIT license and covers all of the core features for Chat, RAG, Agents, and Actions.
|
||||
- Onyx Community Edition (CE) is available freely under the MIT license.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
|
||||
|
||||
For feature details, check out [our website](https://www.onyx.app/pricing?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme).
|
||||
|
||||
## 👪 Community
|
||||
|
||||
|
||||
## 👪 Community
|
||||
Join our open source community on **[Discord](https://discord.gg/TDJ59cGV2X)**!
|
||||
|
||||
## 💡 Contributing
|
||||
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -1,104 +0,0 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
@@ -1,54 +0,0 @@
|
||||
"""csv to tabular chat file type
|
||||
|
||||
Revision ID: 8188861f4e92
|
||||
Revises: d8cdfee5df80
|
||||
Create Date: 2026-03-31 19:23:05.753184
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8188861f4e92"
|
||||
down_revision = "d8cdfee5df80"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET files = (
|
||||
SELECT jsonb_agg(
|
||||
CASE
|
||||
WHEN elem->>'type' = 'csv'
|
||||
THEN jsonb_set(elem, '{type}', '"tabular"')
|
||||
ELSE elem
|
||||
END
|
||||
)
|
||||
FROM jsonb_array_elements(files) AS elem
|
||||
)
|
||||
WHERE files::text LIKE '%"type": "csv"%'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET files = (
|
||||
SELECT jsonb_agg(
|
||||
CASE
|
||||
WHEN elem->>'type' = 'tabular'
|
||||
THEN jsonb_set(elem, '{type}', '"csv"')
|
||||
ELSE elem
|
||||
END
|
||||
)
|
||||
FROM jsonb_array_elements(files) AS elem
|
||||
)
|
||||
WHERE files::text LIKE '%"type": "tabular"%'
|
||||
"""
|
||||
)
|
||||
@@ -1,139 +0,0 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists: Any = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -1,84 +0,0 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# The no-auth placeholder user must NOT be assigned to default groups.
|
||||
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
|
||||
# user when the first real user registers; group membership rows would cause
|
||||
# an FK violation on that DELETE.
|
||||
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -1,55 +0,0 @@
|
||||
"""add skipped to userfilestatus
|
||||
|
||||
Revision ID: d8cdfee5df80
|
||||
Revises: 1d78c0ca7853
|
||||
Create Date: 2026-04-01 10:47:12.593950
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d8cdfee5df80"
|
||||
down_revision = "1d78c0ca7853"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = (
|
||||
"PROCESSING",
|
||||
"INDEXING",
|
||||
"COMPLETED",
|
||||
"SKIPPED",
|
||||
"FAILED",
|
||||
"CANCELED",
|
||||
"DELETING",
|
||||
)
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'")
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -5,7 +5,6 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.hooks",
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
|
||||
@@ -55,15 +55,6 @@ ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
if not MULTI_TENANT:
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
|
||||
|
||||
@@ -54,27 +54,35 @@ def perform_ttl_management_task(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
failures = 0
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
except Exception:
|
||||
failures += 1
|
||||
logger.exception(
|
||||
"Failed to delete chat session "
|
||||
f"user_id={user_id} session_id={session_id}, "
|
||||
"continuing with remaining sessions"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
success=failures == 0,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
|
||||
f"TTL management task failed. user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
|
||||
@@ -69,7 +69,5 @@ EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
# Hook extensions
|
||||
"/admin/hooks",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -36,16 +36,13 @@ from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -283,9 +280,7 @@ class ScimDAL(DAL):
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
|
||||
)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -526,22 +521,6 @@ class ScimDAL(DAL):
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def add_permission_grant_to_group(
|
||||
self,
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
grant_source: GrantSource,
|
||||
) -> None:
|
||||
"""Grant a permission to a group and flush."""
|
||||
self._session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=grant_source,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
|
||||
@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -39,7 +36,6 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -259,7 +255,6 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -274,7 +269,6 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -282,8 +276,6 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -294,7 +286,6 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -304,8 +295,6 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -489,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -510,8 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -819,10 +796,6 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
recompute_user_permissions__no_commit(
|
||||
list(set(added_user_ids) | set(removed_user_ids)), db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -862,19 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -903,10 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,385 +0,0 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def _execute_hook_impl(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.
|
||||
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
@@ -15,7 +15,6 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.features.hooks.api import router as hook_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
@@ -139,7 +138,6 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, evals_router)
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
@@ -52,25 +52,16 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -495,7 +486,6 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -516,25 +506,13 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
@@ -564,8 +542,7 @@ def replace_user(
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
is_reactivation = user_resource.active and not user.is_active
|
||||
if is_reactivation:
|
||||
if user_resource.active and not user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
@@ -579,12 +556,6 @@ def replace_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
@@ -650,7 +621,6 @@ def patch_user(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
is_reactivation = patched.active and not user.is_active
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
@@ -679,12 +649,6 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
@@ -893,11 +857,6 @@ def create_group(
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if group_resource.displayName in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
@@ -920,18 +879,8 @@ def create_group(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
# Every group gets the "basic" permission by default.
|
||||
dal.add_permission_grant_to_group(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
# Recompute permissions for initial members.
|
||||
recompute_user_permissions__no_commit(member_uuids, db_session)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
@@ -962,36 +911,14 @@ def replace_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if (
|
||||
group_resource.displayName in _RESERVED_GROUP_NAMES
|
||||
and group_resource.displayName != group.name
|
||||
):
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
# Capture old member IDs before replacing so we can recompute their
|
||||
# permissions after they are removed from the group.
|
||||
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
# Recompute permissions for current members (batch) and removed members.
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
removed_ids = list(old_member_ids - set(member_uuids))
|
||||
recompute_user_permissions__no_commit(removed_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
@@ -1034,19 +961,8 @@ def patch_group(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and new_name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if new_name and new_name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
|
||||
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
affected_uuids: list[UUID] = []
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
@@ -1057,15 +973,10 @@ def patch_group(
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
affected_uuids.extend(add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
affected_uuids.extend(remove_uuids)
|
||||
|
||||
# Recompute permissions for all users whose group membership changed.
|
||||
recompute_user_permissions__no_commit(affected_uuids, db_session)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
@@ -1091,21 +1002,11 @@ def delete_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
|
||||
|
||||
# Capture member IDs before deletion so we can recompute their permissions.
|
||||
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
|
||||
# Recompute permissions for users who lost this group membership.
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -43,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -60,50 +56,27 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
@@ -127,9 +100,6 @@ def rename_user_group_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -215,9 +185,6 @@ def delete_user_group(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -22,7 +22,6 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -75,21 +74,18 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
@@ -5,8 +5,6 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -53,19 +50,19 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -80,6 +80,7 @@ from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -119,13 +120,11 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
@@ -501,21 +500,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user = user_by_session
|
||||
|
||||
if (
|
||||
user.account_type.is_web_login()
|
||||
user.role.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.account_type.is_web_login()
|
||||
or not user_create.role.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -529,21 +525,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
user.account_type.is_web_login()
|
||||
user.role.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.account_type.is_web_login()
|
||||
or not user_create.role.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
@@ -580,38 +573,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
user.pinned_assistants = default_persona_ids
|
||||
|
||||
def _upgrade_user_to_standard__sync(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
user_create: UserCreate,
|
||||
) -> None:
|
||||
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
|
||||
|
||||
All writes happen in a single sync transaction so neither the field
|
||||
update nor the group assignment is visible without the other.
|
||||
"""
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.hashed_password = self.password_helper.hash(
|
||||
user_create.password
|
||||
)
|
||||
sync_user.is_verified = user_create.is_verified or False
|
||||
sync_user.role = user_create.role
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
assign_user_to_default_groups__no_commit(
|
||||
sync_db,
|
||||
sync_user,
|
||||
is_admin=(user_create.role == UserRole.ADMIN),
|
||||
)
|
||||
sync_db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
"User %s not found in sync session during upgrade to standard; "
|
||||
"skipping upgrade",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -733,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -766,7 +726,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
# We must use the existing user in the session if it matches
|
||||
# the user we just got by email/oauth. Note that this only applies
|
||||
# to multi-tenant, due to the overwriting of the user_db
|
||||
@@ -783,25 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user.id)
|
||||
assert user is not None
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -887,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -1036,7 +975,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
@@ -1532,7 +1471,7 @@ async def _get_or_create_user_from_jwt(
|
||||
if not user.is_active:
|
||||
logger.warning("Inactive user %s attempted JWT login; skipping", email)
|
||||
return None
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
except exceptions.UserNotExists:
|
||||
logger.info("Provisioning user %s from JWT login", email)
|
||||
@@ -1553,7 +1492,7 @@ async def _get_or_create_user_from_jwt(
|
||||
email,
|
||||
)
|
||||
return None
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
logger.warning(
|
||||
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
|
||||
email,
|
||||
@@ -1615,7 +1554,6 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
|
||||
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.hooks",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
|
||||
@@ -319,11 +319,6 @@ def monitor_indexing_attempt_progress(
|
||||
)
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
total_batches: int | str = (
|
||||
coordination_status.total_batches
|
||||
if coordination_status.total_batches is not None
|
||||
else "?"
|
||||
)
|
||||
if coordination_status.found:
|
||||
task_logger.info(
|
||||
f"Indexing attempt progress: "
|
||||
@@ -331,7 +326,7 @@ def monitor_indexing_attempt_progress(
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"completed_batches={coordination_status.completed_batches} "
|
||||
f"total_batches={total_batches} "
|
||||
f"total_batches={coordination_status.total_batches or '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
f"elapsed={(current_db_time - attempt.time_created).seconds}"
|
||||
@@ -415,7 +410,7 @@ def check_indexing_completion(
|
||||
logger.info(
|
||||
f"Indexing status: "
|
||||
f"indexing_completed={indexing_completed} "
|
||||
f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} "
|
||||
f"batches_processed={batches_processed}/{batches_total or '?'} "
|
||||
f"total_docs={coordination_status.total_docs} "
|
||||
f"total_chunks={coordination_status.total_chunks} "
|
||||
f"total_failures={coordination_status.total_failures}"
|
||||
|
||||
@@ -1,33 +1,25 @@
|
||||
# Overview of Context Management
|
||||
|
||||
This document reviews some design decisions around the main agent-loop powering Onyx's chat flow.
|
||||
It is highly recommended for all engineers contributing to this flow to be familiar with the concepts here.
|
||||
|
||||
> Note: it is assumed the reader is familiar with the Onyx product and features such as Projects, User files, Citations, etc.
|
||||
|
||||
## System Prompt
|
||||
|
||||
The system prompt is a default prompt that comes packaged with the system. Users can edit the default prompt and it will be persisted in the database.
|
||||
|
||||
Some parts of the system prompt are dynamically updated / inserted:
|
||||
|
||||
- Datetime of the message sent
|
||||
- Tools description of when to use certain tools depending on if the tool is available in that cycle
|
||||
- If the user has just called a search related tool, then a section about citations is included
|
||||
|
||||
## Custom Agent Prompt
|
||||
|
||||
## Custom Agent Prompt
|
||||
The custom agent is inserted as a user message above the most recent user message, it is dynamically moved in the history as the user sends more messages.
|
||||
If the user has opted to completely replace the System Prompt, then this Custom Agent prompt replaces the system prompt and does not move along the history.
|
||||
|
||||
|
||||
## How Files are handled
|
||||
|
||||
On upload, Files are processed for tokens, if too many tokens to fit in the context, it’s considered a failed inclusion. This is done using the LLM tokenizer.
|
||||
|
||||
- In many cases, there is not a known tokenizer for each LLM so there is a default tokenizer used as a catchall.
|
||||
- File upload happens in 2 parts - the actual upload + token counting.
|
||||
- Files are added into chat context as a “point in time” inclusion and move up the context window as the conversation progresses.
|
||||
Every file knows how many tokens it is (model agnostic), image files have some assumed number of tokens.
|
||||
Every file knows how many tokens it is (model agnostic), image files have some assumed number of tokens.
|
||||
|
||||
Image files are attached to User Messages also as point in time inclusions.
|
||||
|
||||
@@ -35,8 +27,8 @@ Image files are attached to User Messages also as point in time inclusions.
|
||||
Files selected from the search results are also counted as “point in time” inclusions. Files that are too large cannot be selected.
|
||||
For these files, the "entire file" does not exist for most connectors, it's pieced back together from the search engine.
|
||||
|
||||
## Projects
|
||||
|
||||
## Projects
|
||||
If a Project contains few enough files that it all fits in the model context, we keep it close enough in the history to ensure it is easy for the LLM to
|
||||
access. Note that the project documents are assumed to be quite useful and that they should 1. never be dropped from context, 2. is not just a needle in
|
||||
a haystack type search with a strong keyword to make the LLM attend to it.
|
||||
@@ -44,12 +36,11 @@ a haystack type search with a strong keyword to make the LLM attend to it.
|
||||
Project files are vectorized and stored in the Search Engine so that if the user chooses a model with less context than the number of tokens in the project,
|
||||
the system can RAG over the project files.
|
||||
|
||||
|
||||
## How documents are represented
|
||||
|
||||
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix string to
|
||||
make the context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of
|
||||
message rather than a user message.
|
||||
|
||||
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix to make the
|
||||
context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of message
|
||||
rather than a user message.
|
||||
```
|
||||
Here are some documents provided for context, they may not all be relevant:
|
||||
{
|
||||
@@ -59,37 +50,33 @@ Here are some documents provided for context, they may not all be relevant:
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Documents are represented with the `document` key so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
|
||||
Documents are represented with document so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
|
||||
translate this into links and other UI elements. What the LLM sees is far simpler to reduce noise/hallucinations.
|
||||
|
||||
Note that documents included in a single turn should be collapsed into a single user message.
|
||||
|
||||
Search tools also give URLs to the LLM so that open_url (a separate tool) can be called on them.
|
||||
Search tools give URLs to the LLM though so that open_url (a separate tool) can be called on them.
|
||||
|
||||
|
||||
## Reminders
|
||||
|
||||
To ensure the LLM follows certain specific instructions, instructions are added at the very end of the chat context as a user message. If a search related
|
||||
tool is used, a citation reminder is always added. Otherwise, by default there is no reminder. If the user configures reminders, those are added to the
|
||||
final message. If a search related tool just ran and the user has reminders, both appear in a single message.
|
||||
|
||||
If a search related tool is called at any point during the turn, the reminder will remain at the end until the turn is over and the agent has responded.
|
||||
|
||||
## Tool Calls
|
||||
|
||||
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are current replaced with a hardcoded
|
||||
## Tool Calls
|
||||
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are today replaced with a hardcoded
|
||||
string saying it is no longer available. Tool Call details like the search query and other arguments are kept in the history as this is information
|
||||
rich and generally very few tokens.
|
||||
|
||||
> Note: in the Internal Search flow with query expansion, the Tool Call which was actually run differs from what the LLM provided as arguments.
|
||||
> What the LLM sees in the history (to be most informative for future calls) is the full set of expanded queries.
|
||||
|
||||
**Possible Future Extension**:
|
||||
Instead of dropping the Tool Call response, we might summarize it using an LLM so that it is just 1-2 sentences and captures the main points. That said,
|
||||
this is questionable value add because anything relevant and useful should be already captured in the Agent response.
|
||||
|
||||
## Examples
|
||||
|
||||
## Examples
|
||||
```
|
||||
S -> System Message
|
||||
CA -> Custom Agent as a User Message
|
||||
@@ -111,15 +98,15 @@ Flow with Project and File Upload
|
||||
S, CA, P, F, U1, A1 -- user sends another message -> S, F, U1, A1, CA, P, U2, A2
|
||||
- File stays in place, above the user message
|
||||
- Project files move along the chain as new messages are sent
|
||||
- Custom Agent prompt comes before project files which come before user uploaded files in each turn
|
||||
- Custom Agent prompt comes before project files which comes before user uploaded files in each turn
|
||||
|
||||
Reminders during a single Turn
|
||||
S, U1, TC, TR, R -- agent calls another tool -> S, U1, TC, TR, TC, TR, R, A1
|
||||
- Reminder moved to the end
|
||||
```
|
||||
|
||||
## Product considerations
|
||||
|
||||
## Product considerations
|
||||
Project files are important to the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
|
||||
those files. The LLM is much better at referencing documents close to the end of the context window so keeping it there for ease of access.
|
||||
|
||||
@@ -130,9 +117,9 @@ User Message further away. This tradeoff is accepted for Projects because of the
|
||||
Reminder are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. It is less detailed than the system prompt
|
||||
and should be very targetted for it to work reliably and also not interfere with the last user message.
|
||||
|
||||
## Reasons / Experiments
|
||||
|
||||
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrades performance of the system especially when the instructions
|
||||
## Reasons / Experiments
|
||||
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrade performance of the system especially when the instructions
|
||||
are orthogonal (or even possibly contradictory) to the system prompt. For weaker models, it causes strange artifacts in tool calls and final responses
|
||||
that completely ruins the user experience. Empirically, this way works better across a range of models especially when the history gets longer.
|
||||
Having the Custom Agent instructions not move means it fades more as the chat gets long which is also not ok from a UX perspective.
|
||||
@@ -159,10 +146,10 @@ In a similar concept, LLM instructions in the system prompt are structured speci
|
||||
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
|
||||
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
|
||||
even just a paragraph away like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
|
||||
rate by even just moving the same statement a few sentences.
|
||||
rate even just moving the same statement a few sentences.
|
||||
|
||||
|
||||
## Other related pointers
|
||||
|
||||
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.
|
||||
|
||||
---
|
||||
@@ -173,38 +160,32 @@ rate by even just moving the same statement a few sentences.
|
||||
Turn: User sends a message and AI does some set of things and responds
|
||||
Step/Cycle: 1 single LLM inference given some context and some tools
|
||||
|
||||
## 1. Top Level (process_message function):
|
||||
|
||||
## 1. Top Level (process_message function):
|
||||
This function can be thought of as the set-up and validation layer. It ensures that the database is in a valid state, reads the
|
||||
messages in the session and sets up all the necessary items to run the chat loop and state containers. The major things it does
|
||||
are:
|
||||
|
||||
- Validates the request
|
||||
- Builds the chat history for the session
|
||||
- Fetches any additional context such as files and images
|
||||
- Prepares all of the tools for the LLM
|
||||
- Creates the state container objects for use in the loop
|
||||
|
||||
### Execution (`_run_models` function):
|
||||
|
||||
Each model runs in its own worker thread inside a `ThreadPoolExecutor`. Workers write packets to a shared
|
||||
`merged_queue` via an `Emitter`; the main thread drains the queue and yields packets in arrival order. This
|
||||
means the top level is isolated from the LLM flow and can yield packets as soon as they are produced. If a
|
||||
worker fails, the main thread yields a `StreamingError` for that model and keeps the other models running.
|
||||
All saving and database operations are handled by the main thread after the workers complete (or by the
|
||||
workers themselves via self-completion if the drain loop exits early).
|
||||
### Wrapper (run_chat_loop_with_state_containers function):
|
||||
This wrapper is used to run the LLM flow in a background thread and monitor the emitter for stop signals. This means the top
|
||||
level is as isolated from the LLM flow as possible and can continue to yield packets as soon as they are available from the lower
|
||||
levels. This also means that if the lower levels fail, the top level will still guarantee a reasonable response to the user.
|
||||
All of the saving and database operations are abstracted away from the lower levels.
|
||||
|
||||
### Emitter
|
||||
|
||||
The emitter is an object that lower levels use to send packets without needing to yield them all the way back
|
||||
up the call stack. Each `Emitter` tags every packet with a `model_index` and places it on the shared
|
||||
`merged_queue` as a `(model_idx, packet)` tuple. The drain loop in `_run_models` consumes these tuples and
|
||||
yields the packets to the caller. Both the emitter and the state container are mutating state objects used
|
||||
only to accumulate state. There should be no logic dependent on the states of these objects, especially in
|
||||
the lower levels. The emitter should only take packets and should not be used for other things.
|
||||
The emitter is designed to be an object queue so that lower levels do not need to yield objects all the way back to the top.
|
||||
This way the functions can be better designed (not everything as a generator) and more easily tested. The wrapper around the
|
||||
LLM flow (run_chat_loop_with_state_containers) is used to monitor the emitter and handle packets as soon as they are available
|
||||
from the lower levels. Both the emitter and the state container are mutating state objects and only used to accumulate state.
|
||||
There should be no logic dependent on the states of these objects, especially in the lower levels. The emitter should only take
|
||||
packets and should not be used for other things.
|
||||
|
||||
### State Container
|
||||
|
||||
The state container is used to accumulate state during the LLM flow. Similar to the emitter, it should not be used for logic,
|
||||
only for accumulating state. It is used to gather all of the necessary information for saving the chat turn into the database.
|
||||
So it will accumulate answer tokens, reasoning tokens, tool calls, citation info, etc. This is used at the end of the flow once
|
||||
@@ -212,40 +193,35 @@ the lower level is completed whether on its own or stopped by the user. At that
|
||||
the database. The state container can be added to by any of the underlying layers, this is fine.
|
||||
|
||||
### Stopping Generation
|
||||
A stop signal is checked every 300ms by the wrapper around the LLM flow. The signal itself
|
||||
is stored in Redis and is set by the user calling the stop endpoint. The wrapper ensures that no matter what the lower level is
|
||||
doing at the time, the thread can be killed by the top level. It does not require a cooperative cancellation from the lower level
|
||||
and in fact the lower level does not know about the stop signal at all.
|
||||
|
||||
The drain loop in `_run_models` checks `check_is_connected()` every 50 ms (on queue timeout). The signal itself
|
||||
is stored in Redis and is set by the user calling the stop endpoint. On disconnect, the drain loop saves
|
||||
partial state for every model, yields an `OverallStop(stop_reason="user_cancelled")` packet, and returns.
|
||||
A `drain_done` event signals emitters to stop blocking so worker threads can exit quickly. Workers that
|
||||
already completed successfully will self-complete (persist their response) if the drain loop exited before
|
||||
reaching the normal completion path.
|
||||
|
||||
## 2. LLM Loop (run_llm_loop function)
|
||||
|
||||
This function handles the logic of the Turn. It's essentially a while loop where context is added and modified (according what
|
||||
is outlined in the first half of this doc). Its main functionality is:
|
||||
|
||||
- Translate and truncate the context for the LLM inference
|
||||
- Add context modifiers like reminders, updates to the system prompts, etc.
|
||||
- Run tool calls and gather results
|
||||
- Build some of the objects stored in the state container.
|
||||
|
||||
## 3. LLM Step (run_llm_step function)
|
||||
|
||||
## 3. LLM Step (run_llm_step function)
|
||||
This function is a single inference of the LLM. It's a wrapper around the LLM stream function which handles packet translations
|
||||
so that the Emitter can emit individual tokens as soon as they arrive. It also keeps track of the different sections since they
|
||||
do not all come at once (reasoning, answers, tool calls are all built up token by token). This layer also tracks the different
|
||||
tool calls and returns that to the LLM Loop to execute.
|
||||
|
||||
|
||||
## Things to know
|
||||
|
||||
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
|
||||
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
|
||||
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
|
||||
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
|
||||
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
|
||||
|
||||
- There are 3 representations of a message, each scoped to a different layer:
|
||||
1. **ChatMessage** — The database model. Should be converted into ChatMessageSimple early and never passed deep into the flow.
|
||||
2. **ChatMessageSimple** — The canonical data model used throughout the codebase. This is the rich, full-featured representation
|
||||
of a message. Any modifications or additions to message structure should be made here.
|
||||
3. **LanguageModelInput** — The LLM-facing representation. Intentionally minimal so the LLM interface layer stays clean and
|
||||
easy to maintain/extend.
|
||||
- There are 3 representations of "message". The first is the database model ChatMessage, this one should be translated away and
|
||||
not used deep into the flow. The second is ChatMessageSimple which is the data model which should be used throughout the code
|
||||
as much as possible. If modifications/additions are needed, it should be to this object. This is the rich representation of a
|
||||
message for the code. Finally there is the LanguageModelInput representation of a message. This one is for the LLM interface
|
||||
layer and is as stripped down as possible so that the LLM interface can be clean and easy to maintain/extend.
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
from collections.abc import Generator
|
||||
from queue import Empty
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedContextFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import SearchParams
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.tools.models import ChatFile
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
@@ -170,45 +161,112 @@ class ChatStateContainer:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
class AvailableFiles(BaseModel):
|
||||
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
|
||||
def run_chat_loop_with_state_containers(
|
||||
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
) -> Generator[Packet, None]:
|
||||
"""
|
||||
Explicit wrapper function that runs a function in a background thread
|
||||
with event streaming capabilities.
|
||||
|
||||
# IDs from the ``user_file`` table (project / persona-attached files).
|
||||
user_file_ids: list[UUID] = []
|
||||
# IDs from the ``file_record`` table (chat-attached files).
|
||||
chat_file_ids: list[UUID] = []
|
||||
The wrapped function should accept emitter as first arg and use it to emit
|
||||
Packet objects. This wrapper polls every 300ms to check if stop signal is set.
|
||||
|
||||
Args:
|
||||
func: The function to wrap (should accept emitter and state_container as first and second args)
|
||||
completion_callback: Callback function to call when the function completes
|
||||
emitter: Emitter instance for sending packets
|
||||
state_container: ChatStateContainer instance for accumulating state
|
||||
is_connected: Callable that returns False when stop signal is set
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatTurnSetup:
|
||||
"""Immutable context produced by ``build_chat_turn`` and consumed by ``_run_models``."""
|
||||
Usage:
|
||||
packets = run_chat_loop_with_state_containers(
|
||||
my_func,
|
||||
completion_callback=completion_callback,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_func,
|
||||
)
|
||||
for packet in packets:
|
||||
# Process packets
|
||||
pass
|
||||
"""
|
||||
|
||||
new_msg_req: SendMessageRequest
|
||||
chat_session: ChatSession
|
||||
persona: Persona
|
||||
user_message: ChatMessage
|
||||
user_identity: LLMUserIdentity
|
||||
llms: list[LLM] # length 1 for single-model, N for multi-model
|
||||
model_display_names: list[str] # parallel to llms
|
||||
simple_chat_history: list[ChatMessageSimple]
|
||||
extracted_context_files: ExtractedContextFiles
|
||||
reserved_messages: list[ChatMessage] # length 1 for single, N for multi
|
||||
reserved_token_count: int
|
||||
search_params: SearchParams
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
available_files: AvailableFiles
|
||||
tool_id_to_name_map: dict[int, str]
|
||||
forced_tool_id: int | None
|
||||
files: list[ChatLoadedFile]
|
||||
chat_files_for_tools: list[ChatFile]
|
||||
custom_agent_prompt: str | None
|
||||
user_memory_context: UserMemoryContext
|
||||
# For deep research: was the last assistant message a clarification request?
|
||||
skip_clarification: bool
|
||||
check_is_connected: Callable[[], bool]
|
||||
cache: CacheBackend
|
||||
# Execution params forwarded to per-model tool construction
|
||||
bypass_acl: bool
|
||||
slack_context: SlackContext | None
|
||||
custom_tool_additional_headers: dict[str, str] | None
|
||||
mcp_headers: dict[str, str] | None
|
||||
def run_with_exception_capture() -> None:
|
||||
try:
|
||||
chat_loop_func(emitter, state_container)
|
||||
except Exception as e:
|
||||
# If execution fails, emit an exception packet
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
# Run the function in a background thread
|
||||
thread = run_in_background(run_with_exception_capture)
|
||||
|
||||
pkt: Packet | None = None
|
||||
last_turn_index = 0 # Track the highest turn_index seen for stop packet
|
||||
last_cancel_check = time.monotonic()
|
||||
cancel_check_interval = 0.3 # Check for cancellation every 300ms
|
||||
try:
|
||||
while True:
|
||||
# Poll queue with 300ms timeout for natural stop signal checking
|
||||
# the 300ms timeout is to avoid busy-waiting and to allow the stop signal to be checked regularly
|
||||
try:
|
||||
pkt = emitter.bus.get(timeout=0.3)
|
||||
except Empty:
|
||||
if not is_connected():
|
||||
# Stop signal detected
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = time.monotonic()
|
||||
continue
|
||||
|
||||
if pkt is not None:
|
||||
# Track the highest turn_index for the stop packet
|
||||
if pkt.placement and pkt.placement.turn_index > last_turn_index:
|
||||
last_turn_index = pkt.placement.turn_index
|
||||
|
||||
if isinstance(pkt.obj, OverallStop):
|
||||
yield pkt
|
||||
break
|
||||
elif isinstance(pkt.obj, PacketException):
|
||||
raise pkt.obj.exception
|
||||
else:
|
||||
yield pkt
|
||||
|
||||
# Check for cancellation periodically even when packets are flowing
|
||||
# This ensures stop signal is checked during active streaming
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_cancel_check >= cancel_check_interval:
|
||||
if not is_connected():
|
||||
# Stop signal detected during streaming
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
break
|
||||
last_cancel_check = current_time
|
||||
finally:
|
||||
# Wait for thread to complete on normal exit to propagate exceptions and ensure cleanup.
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi.datastructures import Headers
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
@@ -52,60 +51,6 @@ logger = setup_logger()
|
||||
IMAGE_GENERATION_TOOL_NAME = "generate_image"
|
||||
|
||||
|
||||
class FileContextResult(BaseModel):
|
||||
"""Result of building a file's LLM context representation."""
|
||||
|
||||
message: ChatMessageSimple
|
||||
tool_metadata: FileToolMetadata
|
||||
|
||||
|
||||
def build_file_context(
|
||||
tool_file_id: str,
|
||||
filename: str,
|
||||
file_type: ChatFileType,
|
||||
content_text: str | None = None,
|
||||
token_count: int = 0,
|
||||
approx_char_count: int | None = None,
|
||||
) -> FileContextResult:
|
||||
"""Build the LLM context representation for a single file.
|
||||
|
||||
Centralises how files should appear in the LLM prompt
|
||||
— the ID that FileReaderTool accepts (``UserFile.id`` for user files).
|
||||
"""
|
||||
if file_type.use_metadata_only():
|
||||
message_text = (
|
||||
f"File: {filename} (id={tool_file_id})\n"
|
||||
"Use the file_reader or python tools to access "
|
||||
"this file's contents."
|
||||
)
|
||||
message = ChatMessageSimple(
|
||||
message=message_text,
|
||||
token_count=max(1, len(message_text) // 4),
|
||||
message_type=MessageType.USER,
|
||||
file_id=tool_file_id,
|
||||
)
|
||||
else:
|
||||
message_text = f"File: {filename}\n{content_text or ''}\nEnd of File"
|
||||
message = ChatMessageSimple(
|
||||
message=message_text,
|
||||
token_count=token_count,
|
||||
message_type=MessageType.USER,
|
||||
file_id=tool_file_id,
|
||||
)
|
||||
|
||||
metadata = FileToolMetadata(
|
||||
file_id=tool_file_id,
|
||||
filename=filename,
|
||||
approx_char_count=(
|
||||
approx_char_count
|
||||
if approx_char_count is not None
|
||||
else len(content_text or "")
|
||||
),
|
||||
)
|
||||
|
||||
return FileContextResult(message=message, tool_metadata=metadata)
|
||||
|
||||
|
||||
def create_chat_session_from_request(
|
||||
chat_session_request: ChatSessionCreationRequest,
|
||||
user_id: UUID | None,
|
||||
@@ -593,7 +538,7 @@ def convert_chat_history(
|
||||
for idx, chat_message in enumerate(chat_history):
|
||||
if chat_message.message_type == MessageType.USER:
|
||||
# Process files attached to this message
|
||||
text_files: list[tuple[ChatLoadedFile, FileDescriptor]] = []
|
||||
text_files: list[ChatLoadedFile] = []
|
||||
image_files: list[ChatLoadedFile] = []
|
||||
|
||||
if chat_message.files:
|
||||
@@ -604,26 +549,34 @@ def convert_chat_history(
|
||||
if loaded_file.file_type == ChatFileType.IMAGE:
|
||||
image_files.append(loaded_file)
|
||||
else:
|
||||
# Text files (DOC, PLAIN_TEXT, TABULAR) are added as separate messages
|
||||
text_files.append((loaded_file, file_descriptor))
|
||||
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
|
||||
text_files.append(loaded_file)
|
||||
|
||||
# Add text files as separate messages before the user message.
|
||||
# Each message is tagged with ``file_id`` so that forgotten files
|
||||
# can be detected after context-window truncation.
|
||||
for text_file, fd in text_files:
|
||||
# Use user_file_id as the FileReaderTool accepts that.
|
||||
# Fall back to the file-store path id.
|
||||
tool_id = fd.get("user_file_id") or text_file.file_id
|
||||
filename = text_file.filename or "unknown"
|
||||
ctx = build_file_context(
|
||||
tool_file_id=tool_id,
|
||||
filename=filename,
|
||||
file_type=text_file.file_type,
|
||||
content_text=text_file.content_text,
|
||||
token_count=text_file.token_count,
|
||||
for text_file in text_files:
|
||||
file_text = text_file.content_text or ""
|
||||
filename = text_file.filename
|
||||
message = (
|
||||
f"File: {filename}\n{file_text}\nEnd of File"
|
||||
if filename
|
||||
else file_text
|
||||
)
|
||||
simple_messages.append(
|
||||
ChatMessageSimple(
|
||||
message=message,
|
||||
token_count=text_file.token_count,
|
||||
message_type=MessageType.USER,
|
||||
image_files=None,
|
||||
file_id=text_file.file_id,
|
||||
)
|
||||
)
|
||||
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
|
||||
file_id=text_file.file_id,
|
||||
filename=filename or "unknown",
|
||||
approx_char_count=len(file_text),
|
||||
)
|
||||
simple_messages.append(ctx.message)
|
||||
all_injected_file_metadata[tool_id] = ctx.tool_metadata
|
||||
|
||||
# Sum token counts from image files (excluding project image files)
|
||||
image_token_count = (
|
||||
|
||||
@@ -1,40 +1,19 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
|
||||
"""Routes packets from LLM/tool execution to the ``_run_models`` drain loop.
|
||||
"""Use this inside tools to emit arbitrary UI progress."""
|
||||
|
||||
Tags every packet with ``model_index`` and places it on ``merged_queue``
|
||||
as a ``(model_idx, packet)`` tuple for ordered consumption downstream.
|
||||
|
||||
Args:
|
||||
merged_queue: Shared queue owned by ``_run_models``.
|
||||
model_idx: Index embedded in packet placements (``0`` for N=1 runs).
|
||||
drain_done: Optional event set by ``_run_models`` when the drain loop
|
||||
exits early (e.g. HTTP disconnect). When set, ``emit`` returns
|
||||
immediately so worker threads can exit fast.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merged_queue: Queue[tuple[int, Packet | Exception | object]],
|
||||
model_idx: int = 0,
|
||||
drain_done: threading.Event | None = None,
|
||||
) -> None:
|
||||
self._model_idx = model_idx
|
||||
self._merged_queue = merged_queue
|
||||
self._drain_done = drain_done
|
||||
def __init__(self, bus: Queue):
|
||||
self.bus = bus
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
placement=base.model_copy(update={"model_index": self._model_idx}),
|
||||
obj=packet.obj,
|
||||
)
|
||||
self._merged_queue.put((self._model_idx, tagged))
|
||||
self.bus.put(packet) # Thread-safe
|
||||
|
||||
|
||||
def get_default_emitter() -> Emitter:
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
return emitter
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -286,9 +286,11 @@ USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
# Profiling adds some overhead to OpenSearch operations. This overhead is
|
||||
# unknown right now. Defaults to True.
|
||||
# unknown right now. It is enabled by default so we can get useful logs for
|
||||
# investigating slow queries. We may never disable it if the overhead is
|
||||
# minimal.
|
||||
OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "true").lower() == "true"
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
)
|
||||
# Whether to disable match highlights for OpenSearch. Defaults to True for now
|
||||
# as we investigate query performance.
|
||||
@@ -940,20 +942,9 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
|
||||
)
|
||||
|
||||
VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
|
||||
# This is the timeout for the client side of the Vespa migration task. When
|
||||
# exceeded, an exception is raised in our code. This value should be higher than
|
||||
# VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT.
|
||||
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
|
||||
os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
|
||||
)
|
||||
# This is the timeout Vespa uses on the server side to know when to wrap up its
|
||||
# traversal and try to report partial results. This differs from the client
|
||||
# timeout above which raises an exception in our code when exceeded. This
|
||||
# timeout allows Vespa to return gracefully. This value should be lower than
|
||||
# VESPA_MIGRATION_REQUEST_TIMEOUT_S. Formatted as <number of seconds>s.
|
||||
VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT = os.environ.get(
|
||||
"VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT", "110s"
|
||||
)
|
||||
|
||||
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
@@ -1088,6 +1079,7 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -11,13 +11,11 @@ from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
from discord.channel import Thread
|
||||
from discord.enums import MessageType
|
||||
from discord.errors import LoginFailure
|
||||
from discord.flags import Intents
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import CredentialInvalidError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -211,19 +209,8 @@ def _manage_async_retrieval(
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
start_task = asyncio.create_task(discord_client.start(token))
|
||||
ready_task = asyncio.create_task(discord_client.wait_until_ready())
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
{start_task, ready_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# start() runs indefinitely once connected, so it only lands
|
||||
# in `done` when login/connection failed — propagate the error.
|
||||
if start_task in done:
|
||||
ready_task.cancel()
|
||||
start_task.result()
|
||||
asyncio.create_task(discord_client.start(token))
|
||||
await discord_client.wait_until_ready()
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
discord_client=discord_client,
|
||||
@@ -289,19 +276,6 @@ class DiscordConnector(PollConnector, LoadConnector):
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
client = Client(intents=Intents.default())
|
||||
try:
|
||||
loop.run_until_complete(client.login(self.discord_bot_token))
|
||||
except LoginFailure as e:
|
||||
raise CredentialInvalidError(f"Invalid Discord bot token: {e}")
|
||||
finally:
|
||||
loop.run_until_complete(client.close())
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
|
||||
@@ -8,6 +8,7 @@ from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
@@ -1486,113 +1487,134 @@ class GoogleDriveConnector(
|
||||
end=end,
|
||||
)
|
||||
|
||||
def _convert_retrieved_files_to_documents(
|
||||
def _extract_docs_from_google_drive(
|
||||
self,
|
||||
drive_files_iter: Iterator[RetrievedDriveFile],
|
||||
checkpoint: GoogleDriveCheckpoint,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
include_permissions: bool,
|
||||
) -> Iterator[Document | ConnectorFailure | HierarchyNode]:
|
||||
"""
|
||||
Converts retrieved files to documents, yielding HierarchyNode
|
||||
objects for ancestor folders before the converted documents.
|
||||
Retrieves and converts Google Drive files to documents.
|
||||
Also yields HierarchyNode objects for ancestor folders.
|
||||
"""
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
for retrieved_file in drive_files_iter:
|
||||
if self.exclude_domain_link_only and has_link_only_permission(
|
||||
retrieved_file.drive_file
|
||||
):
|
||||
continue
|
||||
if retrieved_file.error is None:
|
||||
files_batch.append(retrieved_file)
|
||||
continue
|
||||
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = f"retrieval failure during stage: {failure_stage},"
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=retrieved_file.drive_file.get("id", failure_stage),
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
new_ancestors = self._get_new_ancestors_for_files(
|
||||
files=files_batch,
|
||||
seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids,
|
||||
fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids,
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True,
|
||||
)
|
||||
if new_ancestors:
|
||||
logger.debug(f"Yielding {len(new_ancestors)} new hierarchy nodes")
|
||||
yield from new_ancestors
|
||||
|
||||
func_with_args = [
|
||||
(
|
||||
self._convert_retrieved_file_to_document,
|
||||
(retrieved_file, permission_sync_context),
|
||||
)
|
||||
for retrieved_file in files_batch
|
||||
]
|
||||
raw_results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
|
||||
results: list[Document | ConnectorFailure] = [
|
||||
r for r in raw_results if r is not None
|
||||
]
|
||||
logger.debug(f"batch has {len(results)} docs or failures")
|
||||
yield from results
|
||||
|
||||
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids
|
||||
|
||||
def _convert_retrieved_file_to_document(
|
||||
self,
|
||||
retrieved_file: RetrievedDriveFile,
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
) -> Document | ConnectorFailure | None:
|
||||
"""
|
||||
Converts a single retrieved file to a document.
|
||||
"""
|
||||
try:
|
||||
return convert_drive_item_to_document(
|
||||
# Build permission sync context if needed
|
||||
permission_sync_context = (
|
||||
PermissionSyncContext(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
)
|
||||
if include_permissions
|
||||
else None
|
||||
)
|
||||
|
||||
# Prepare a partial function with the credentials and admin email
|
||||
convert_func = partial(
|
||||
convert_drive_item_to_document,
|
||||
self.creds,
|
||||
self.allow_images,
|
||||
self.size_threshold,
|
||||
permission_sync_context,
|
||||
[retrieved_file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(retrieved_file.drive_file, self.primary_admin_email),
|
||||
retrieved_file.drive_file,
|
||||
)
|
||||
# Fetch files in batches
|
||||
batches_complete = 0
|
||||
files_batch: list[RetrievedDriveFile] = []
|
||||
|
||||
def _yield_batch(
|
||||
files_batch: list[RetrievedDriveFile],
|
||||
) -> Iterator[Document | ConnectorFailure | HierarchyNode]:
|
||||
nonlocal batches_complete
|
||||
|
||||
# First, yield any new ancestor hierarchy nodes
|
||||
new_ancestors = self._get_new_ancestors_for_files(
|
||||
files=files_batch,
|
||||
seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids,
|
||||
fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids,
|
||||
permission_sync_context=permission_sync_context,
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
)
|
||||
if new_ancestors:
|
||||
logger.debug(
|
||||
f"Yielding {len(new_ancestors)} new hierarchy nodes for batch {batches_complete}"
|
||||
)
|
||||
yield from new_ancestors
|
||||
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [
|
||||
(
|
||||
convert_func,
|
||||
(
|
||||
[file.user_email, self.primary_admin_email]
|
||||
+ get_file_owners(
|
||||
file.drive_file, self.primary_admin_email
|
||||
),
|
||||
file.drive_file,
|
||||
),
|
||||
)
|
||||
for file in files_batch
|
||||
]
|
||||
results = cast(
|
||||
list[Document | ConnectorFailure | None],
|
||||
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
|
||||
)
|
||||
logger.debug(
|
||||
f"finished processing batch {batches_complete} with {len(results)} results"
|
||||
)
|
||||
|
||||
docs_and_failures = [result for result in results if result is not None]
|
||||
logger.debug(
|
||||
f"batch {batches_complete} has {len(docs_and_failures)} docs or failures"
|
||||
)
|
||||
|
||||
if docs_and_failures:
|
||||
yield from docs_and_failures
|
||||
batches_complete += 1
|
||||
logger.debug(f"finished yielding batch {batches_complete}")
|
||||
|
||||
for retrieved_file in self._fetch_drive_items(
|
||||
field_type=field_type,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
if self.exclude_domain_link_only and has_link_only_permission(
|
||||
retrieved_file.drive_file
|
||||
):
|
||||
continue
|
||||
if retrieved_file.error is None:
|
||||
files_batch.append(retrieved_file)
|
||||
continue
|
||||
|
||||
# handle retrieval errors
|
||||
failure_stage = retrieved_file.completion_stage.value
|
||||
failure_message = f"retrieval failure during stage: {failure_stage},"
|
||||
failure_message += f"user: {retrieved_file.user_email},"
|
||||
failure_message += f"parent drive/folder: {retrieved_file.parent_id},"
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
yield from _yield_batch(files_batch)
|
||||
checkpoint.retrieved_folder_and_drive_ids = (
|
||||
self._retrieved_folder_and_drive_ids
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error extracting document: "
|
||||
f"{retrieved_file.drive_file.get('name')} from Google Drive"
|
||||
)
|
||||
return ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=retrieved_file.drive_file.get("id", "unknown"),
|
||||
),
|
||||
failure_message=(
|
||||
f"Error extracting document: "
|
||||
f"{retrieved_file.drive_file.get('name')}"
|
||||
),
|
||||
exception=e,
|
||||
)
|
||||
logger.exception(f"Error extracting documents from Google Drive: {e}")
|
||||
raise e
|
||||
|
||||
def _load_from_checkpoint(
|
||||
self,
|
||||
@@ -1616,19 +1638,8 @@ class GoogleDriveConnector(
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids
|
||||
try:
|
||||
field_type = (
|
||||
DriveFileFieldType.WITH_PERMISSIONS
|
||||
if include_permissions or self.exclude_domain_link_only
|
||||
else DriveFileFieldType.STANDARD
|
||||
)
|
||||
drive_files_iter = self._fetch_drive_items(
|
||||
field_type=field_type,
|
||||
checkpoint=checkpoint,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
yield from self._convert_retrieved_files_to_documents(
|
||||
drive_files_iter, checkpoint, include_permissions
|
||||
yield from self._extract_docs_from_google_drive(
|
||||
checkpoint, start, end, include_permissions
|
||||
)
|
||||
except Exception as e:
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
|
||||
@@ -4,8 +4,6 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
@@ -498,41 +496,3 @@ def get_root_folder_id(service: Resource) -> str:
|
||||
.get(fileId="root", fields=GoogleFields.ID.value)
|
||||
.execute()[GoogleFields.ID.value]
|
||||
)
|
||||
|
||||
|
||||
def _extract_file_id_from_web_view_link(web_view_link: str) -> str:
|
||||
parsed = urlparse(web_view_link)
|
||||
path_parts = [part for part in parsed.path.split("/") if part]
|
||||
|
||||
if "d" in path_parts:
|
||||
idx = path_parts.index("d")
|
||||
if idx + 1 < len(path_parts):
|
||||
return path_parts[idx + 1]
|
||||
|
||||
query_params = parse_qs(parsed.query)
|
||||
for key in ("id", "fileId"):
|
||||
value = query_params.get(key)
|
||||
if value and value[0]:
|
||||
return value[0]
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to extract Drive file id from webViewLink: {web_view_link}"
|
||||
)
|
||||
|
||||
|
||||
def get_file_by_web_view_link(
|
||||
service: GoogleDriveService,
|
||||
web_view_link: str,
|
||||
fields: str,
|
||||
) -> GoogleDriveFileType:
|
||||
"""Retrieve a Google Drive file using its webViewLink."""
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
return (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
@@ -1,33 +1,24 @@
|
||||
import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -64,6 +55,7 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
@@ -95,7 +87,6 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -108,18 +99,7 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
@@ -146,33 +126,7 @@ def update_api_key(
|
||||
|
||||
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
|
||||
|
||||
old_role = api_key_user.role
|
||||
api_key_user.role = api_key_args.role
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != api_key_args.role:
|
||||
# Remove from all default groups first.
|
||||
delete_stmt = delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == api_key_user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
else:
|
||||
# No group assigned for LIMITED, but we still need to recompute
|
||||
# since we just removed the old default-group membership above.
|
||||
recompute_user_permissions__no_commit(api_key_user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -131,47 +132,32 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = list(result.scalars().all())
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
@@ -199,7 +185,15 @@ def delete_messages_and_files_from_chat_session(
|
||||
for _, files in messages_with_files:
|
||||
file_store = get_default_file_store()
|
||||
for file_info in files or []:
|
||||
file_store.delete_file(file_id=file_info.get("id"))
|
||||
try:
|
||||
file_store.delete_file(file_id=file_info.get("id"))
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to delete file %s from file store during "
|
||||
"chat session cleanup, skipping.",
|
||||
file_info.get("id"),
|
||||
)
|
||||
continue
|
||||
|
||||
# Delete ChatMessage records - CASCADE constraints will automatically handle:
|
||||
# - ChatMessage__StandardAnswer relationship records
|
||||
@@ -631,91 +625,6 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -938,8 +847,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -13,26 +13,19 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
def is_web_login(self) -> bool:
|
||||
"""Whether this account type supports interactive web login."""
|
||||
return self not in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
)
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -222,7 +215,6 @@ class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
SKIPPED = "SKIPPED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
DELETING = "DELETING"
|
||||
|
||||
@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -46,6 +47,7 @@ async def fetch_user_for_pat(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(
|
||||
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for one or more users.
|
||||
|
||||
Accepts a single UUID or a list. Uses a single query regardless of
|
||||
how many users are passed, avoiding N+1 issues.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
if isinstance(user_ids, (UUID, str)):
|
||||
uid_list = [user_ids]
|
||||
else:
|
||||
uid_list = list(user_ids)
|
||||
|
||||
if not uid_list:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(uid_list),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
|
||||
for uid in uid_list:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid) # type: ignore[arg-type]
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
recompute_user_permissions__no_commit(user_ids, db_session)
|
||||
@@ -7,7 +7,6 @@ from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
@@ -18,7 +17,6 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Project__UserFile
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
@@ -36,19 +34,9 @@ class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
rejected_files: list[RejectedFile]
|
||||
id_to_temp_id: dict[str, str]
|
||||
# Filenames that should be stored but not indexed.
|
||||
skip_indexing_filenames: set[str] = Field(default_factory=set)
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@property
|
||||
def indexable_files(self) -> list[UserFile]:
|
||||
return [
|
||||
uf
|
||||
for uf in self.user_files
|
||||
if (uf.name or "") not in self.skip_indexing_filenames
|
||||
]
|
||||
|
||||
|
||||
def build_hashed_file_key(file: UploadFile) -> str:
|
||||
name_prefix = (file.filename or "")[:50]
|
||||
@@ -82,7 +70,6 @@ def create_user_files(
|
||||
)
|
||||
if new_temp_id is not None:
|
||||
id_to_temp_id[str(new_id)] = new_temp_id
|
||||
should_skip = (file.filename or "") in categorized_files.skip_indexing
|
||||
new_file = UserFile(
|
||||
id=new_id,
|
||||
user_id=user.id,
|
||||
@@ -94,7 +81,6 @@ def create_user_files(
|
||||
link_url=link_url,
|
||||
content_type=file.content_type,
|
||||
file_type=file.content_type,
|
||||
status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING,
|
||||
last_accessed_at=datetime.datetime.now(datetime.timezone.utc),
|
||||
)
|
||||
# Persist the UserFile first to satisfy FK constraints for association table
|
||||
@@ -112,7 +98,6 @@ def create_user_files(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files.skip_indexing,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,7 +123,6 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files = categorized_files_result.user_files
|
||||
rejected_files = categorized_files_result.rejected_files
|
||||
id_to_temp_id = categorized_files_result.id_to_temp_id
|
||||
indexable_files = categorized_files_result.indexable_files
|
||||
# Trigger per-file processing immediately for the current tenant
|
||||
tenant_id = get_current_tenant_id()
|
||||
for rejected_file in rejected_files:
|
||||
@@ -150,12 +134,12 @@ def upload_files_to_user_files_with_indexing(
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in indexable_files:
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in indexable_files:
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
@@ -171,7 +155,6 @@ def upload_files_to_user_files_with_indexing(
|
||||
user_files=user_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
skip_indexing_filenames=categorized_files_result.skip_indexing_filenames,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
|
||||
db_session.scalars(
|
||||
select(User.id).where( # type: ignore
|
||||
User.is_active == True, # noqa: E712
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
|
||||
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
|
||||
)
|
||||
).all()
|
||||
|
||||
@@ -9,17 +9,12 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -28,53 +23,13 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
|
||||
UserRole.SLACK_USER: AccountType.BOT,
|
||||
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
|
||||
}
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database.
|
||||
Dual-writes account_type to keep it in sync with role and
|
||||
reconciles default-group membership (Admin / Basic)."""
|
||||
old_role = user.role
|
||||
"""Update a user's role in the database."""
|
||||
user.role = new_role
|
||||
# Note: setting account_type to BOT or EXT_PERM_USER causes
|
||||
# assign_user_to_default_groups__no_commit to early-return, which is
|
||||
# intentional — these account types should not be in default groups.
|
||||
if new_role in _ROLE_TO_ACCOUNT_TYPE:
|
||||
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
|
||||
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
|
||||
# Upgrading from a non-web-login account type to a web role
|
||||
user.account_type = AccountType.STANDARD
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != new_role:
|
||||
# Remove from all default groups first.
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if new_role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
user,
|
||||
is_admin=(new_role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -92,16 +47,8 @@ def activate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Activate a user by setting is_active to True.
|
||||
|
||||
Also reconciles default-group membership — the user may have been
|
||||
created while inactive or deactivated before the backfill migration.
|
||||
"""
|
||||
"""Activate a user by setting is_active to True."""
|
||||
user.is_active = True
|
||||
if user.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -282,9 +229,7 @@ def get_memories_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> Sequence[Memory]:
|
||||
return db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
|
||||
).all()
|
||||
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
|
||||
|
||||
def update_user_pinned_assistants(
|
||||
|
||||
@@ -17,9 +17,8 @@ from sqlalchemy.sql.expression import or_
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -28,17 +27,11 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole,
|
||||
current_role: UserRole,
|
||||
current_account_type: AccountType,
|
||||
explicit_override: bool = False,
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user role update is valid.
|
||||
@@ -48,18 +41,19 @@ def validate_user_role_update(
|
||||
- requested role is a slack user
|
||||
- requested role is an external permissioned user
|
||||
- requested role is a limited user
|
||||
- current account type is BOT (slack user)
|
||||
- current account type is EXT_PERM_USER
|
||||
- current role is a slack user
|
||||
- current role is an external permissioned user
|
||||
- current role is a limited user
|
||||
"""
|
||||
|
||||
if current_account_type == AccountType.BOT:
|
||||
if current_role == UserRole.SLACK_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_account_type == AccountType.EXT_PERM_USER:
|
||||
if current_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
@@ -304,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -313,9 +306,8 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.account_type == AccountType.EXT_PERM_USER:
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -352,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -384,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -505,14 +421,13 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
stmt = (
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -520,11 +435,7 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -932,7 +932,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
def search_for_document_ids(
|
||||
self,
|
||||
body: dict[str, Any],
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN,
|
||||
search_type: OpenSearchSearchType = OpenSearchSearchType.DOCUMENT_IDS,
|
||||
) -> list[str]:
|
||||
"""Searches the index and returns only document chunk IDs.
|
||||
|
||||
|
||||
@@ -37,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
|
||||
# we have a much higher chance of all 10 of the final desired docs showing up
|
||||
# and getting scored. In worse situations, the final 10 docs don't even show up
|
||||
# as the final 10 (worse than just a miss at the reranking step).
|
||||
# Defaults to 500 for now. Initially this defaulted to 750 but we were seeing
|
||||
# poor search performance; bumped from 100 to 500 to improve recall.
|
||||
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
|
||||
# poor search performance.
|
||||
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 500)
|
||||
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
|
||||
)
|
||||
|
||||
# Number of vectors to examine to decide the top k neighbors for the HNSW
|
||||
@@ -60,7 +60,8 @@ class OpenSearchSearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
RANDOM = "random"
|
||||
DOC_ID_RETRIEVAL = "doc_id_retrieval"
|
||||
ID_RETRIEVAL = "id_retrieval"
|
||||
DOCUMENT_IDS = "document_ids"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
|
||||
@@ -928,7 +928,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search_hits = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL,
|
||||
search_type=OpenSearchSearchType.ID_RETRIEVAL,
|
||||
)
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
|
||||
@@ -20,7 +20,6 @@ from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
from onyx.configs.app_configs import VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
@@ -336,11 +335,6 @@ def get_all_chunks_paginated(
|
||||
"format.tensors": "short-value",
|
||||
"slices": total_slices,
|
||||
"sliceId": slice_id,
|
||||
# When exceeded, Vespa should return gracefully with partial
|
||||
# results. Even if no hits are returned, Vespa should still return a
|
||||
# new continuation token representing a new spot in the linear
|
||||
# traversal.
|
||||
"timeout": VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT,
|
||||
}
|
||||
if continuation_token is not None:
|
||||
params["continuation"] = continuation_token
|
||||
@@ -349,9 +343,6 @@ def get_all_chunks_paginated(
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
with get_vespa_http_client(
|
||||
# When exceeded, an exception is raised in our code. No progress
|
||||
# is saved, and the task will retry this spot in the traversal
|
||||
# later.
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import csv
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
@@ -20,7 +19,6 @@ from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
@@ -355,94 +353,6 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
rows.append(row)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
|
||||
"""
|
||||
Cleans a worksheet matrix by removing rows if there are N consecutive empty
|
||||
rows and removing cols if there are M consecutive empty columns
|
||||
"""
|
||||
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
|
||||
MAX_EMPTY_COLS = 2
|
||||
|
||||
# Row cleanup
|
||||
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
|
||||
|
||||
if not matrix:
|
||||
return matrix
|
||||
|
||||
# Column cleanup — determine which columns to keep without transposing.
|
||||
num_cols = len(matrix[0])
|
||||
keep_cols = _columns_to_keep(matrix, num_cols, max_empty=MAX_EMPTY_COLS)
|
||||
if len(keep_cols) < num_cols:
|
||||
matrix = [[row[c] for c in keep_cols] for row in matrix]
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _columns_to_keep(
|
||||
matrix: list[list[str]], num_cols: int, max_empty: int
|
||||
) -> list[int]:
|
||||
"""Return the indices of columns to keep after removing empty-column runs.
|
||||
|
||||
Uses the same logic as ``_remove_empty_runs`` but operates on column
|
||||
indices so no transpose is needed.
|
||||
"""
|
||||
kept: list[int] = []
|
||||
empty_buffer: list[int] = []
|
||||
|
||||
for col_idx in range(num_cols):
|
||||
col_is_empty = all(not row[col_idx] for row in matrix)
|
||||
if col_is_empty:
|
||||
empty_buffer.append(col_idx)
|
||||
else:
|
||||
kept.extend(empty_buffer[:max_empty])
|
||||
kept.append(col_idx)
|
||||
empty_buffer = []
|
||||
|
||||
return kept
|
||||
|
||||
|
||||
def _remove_empty_runs(
|
||||
rows: list[list[str]],
|
||||
max_empty: int,
|
||||
) -> list[list[str]]:
|
||||
"""Removes entire runs of empty rows when the run length exceeds max_empty.
|
||||
|
||||
Leading empty runs are capped to max_empty, just like interior runs.
|
||||
Trailing empty rows are always dropped since there is no subsequent
|
||||
non-empty row to flush them.
|
||||
"""
|
||||
result: list[list[str]] = []
|
||||
empty_buffer: list[list[str]] = []
|
||||
|
||||
for row in rows:
|
||||
# Check if empty
|
||||
if not any(row):
|
||||
if len(empty_buffer) < max_empty:
|
||||
empty_buffer.append(row)
|
||||
else:
|
||||
# Add upto max empty rows onto the result - that's what we allow
|
||||
result.extend(empty_buffer[:max_empty])
|
||||
# Add the new non-empty row
|
||||
result.append(row)
|
||||
empty_buffer = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
@@ -481,15 +391,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise
|
||||
raise e
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ PLAIN_TEXT_MIME_TYPE = "text/plain"
|
||||
class OnyxMimeTypes:
|
||||
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
|
||||
CSV_MIME_TYPES = {"text/csv"}
|
||||
TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE}
|
||||
TEXT_MIME_TYPES = {
|
||||
PLAIN_TEXT_MIME_TYPE,
|
||||
"text/markdown",
|
||||
@@ -35,12 +34,13 @@ class OnyxMimeTypes:
|
||||
PDF_MIME_TYPE,
|
||||
WORD_PROCESSING_MIME_TYPE,
|
||||
PRESENTATION_MIME_TYPE,
|
||||
SPREADSHEET_MIME_TYPE,
|
||||
"message/rfc822",
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES
|
||||
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
|
||||
)
|
||||
|
||||
EXCLUDED_IMAGE_TYPES = {
|
||||
@@ -53,11 +53,6 @@ class OnyxMimeTypes:
|
||||
|
||||
|
||||
class OnyxFileExtensions:
|
||||
TABULAR_EXTENSIONS = {
|
||||
".csv",
|
||||
".tsv",
|
||||
".xlsx",
|
||||
}
|
||||
PLAIN_TEXT_EXTENSIONS = {
|
||||
".txt",
|
||||
".md",
|
||||
|
||||
@@ -13,21 +13,15 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
# Tabular data files (CSV, XLSX)
|
||||
TABULAR = "tabular"
|
||||
CSV = "csv"
|
||||
|
||||
def is_text_file(self) -> bool:
|
||||
return self in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.TABULAR,
|
||||
ChatFileType.CSV,
|
||||
)
|
||||
|
||||
def use_metadata_only(self) -> bool:
|
||||
"""File types where we can ignore the file content
|
||||
and only use the metadata."""
|
||||
return self in (ChatFileType.TABULAR,)
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
|
||||
|
||||
@@ -110,20 +110,16 @@ def load_user_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
|
||||
# check for plain text normalized version first, then use original file otherwise
|
||||
try:
|
||||
file_io = file_store.read_file(plaintext_file_name, mode="b")
|
||||
# Metadata-only file types preserve their original type so
|
||||
# downstream injection paths can route them correctly.
|
||||
if chat_file_type.use_metadata_only():
|
||||
plaintext_chat_file_type = chat_file_type
|
||||
elif file_io is not None:
|
||||
# if we have plaintext for image (which happens when image
|
||||
# extraction is enabled), we use PLAIN_TEXT type
|
||||
# For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
|
||||
plaintext_chat_file_type = (
|
||||
ChatFileType.PLAIN_TEXT
|
||||
if chat_file_type != ChatFileType.IMAGE
|
||||
else chat_file_type
|
||||
)
|
||||
|
||||
# if we have plaintext for image (which happens when image extraction is enabled), we use PLAIN_TEXT type
|
||||
if file_io is not None:
|
||||
plaintext_chat_file_type = ChatFileType.PLAIN_TEXT
|
||||
else:
|
||||
plaintext_chat_file_type = (
|
||||
ChatFileType.PLAIN_TEXT
|
||||
if chat_file_type != ChatFileType.IMAGE
|
||||
else chat_file_type
|
||||
)
|
||||
|
||||
chat_file = InMemoryChatFile(
|
||||
file_id=str(user_file.file_id),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -6,7 +7,10 @@ from shared_configs.configs import MULTI_TENANT
|
||||
def require_hook_enabled() -> None:
|
||||
"""FastAPI dependency that gates all hook management endpoints.
|
||||
|
||||
Hooks are only available in single-tenant / self-hosted EE deployments.
|
||||
Hooks are only available in single-tenant / self-hosted deployments with
|
||||
HOOK_ENABLED=true explicitly set. Two layers of protection:
|
||||
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
|
||||
2. HOOK_ENABLED flag — explicit opt-in by the operator
|
||||
|
||||
Use as: Depends(require_hook_enabled)
|
||||
"""
|
||||
@@ -15,3 +19,8 @@ def require_hook_enabled() -> None:
|
||||
OnyxErrorCode.SINGLE_TENANT_ONLY,
|
||||
"Hooks are not available in multi-tenant deployments",
|
||||
)
|
||||
if not HOOK_ENABLED:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.ENV_VAR_GATED,
|
||||
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
|
||||
)
|
||||
|
||||
@@ -1,22 +1,79 @@
|
||||
"""CE hook executor.
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
HookSkipped and HookSoftFailed are real classes kept here because
|
||||
process_message.py (CE code) uses isinstance checks against them.
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
response_type=QueryProcessingResponse,
|
||||
)
|
||||
|
||||
execute_hook is the public entry point. It dispatches to _execute_hook_impl
|
||||
via fetch_versioned_implementation so that:
|
||||
- CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped()
|
||||
- EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is a validated Pydantic model instance (response_type)
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
@@ -30,15 +87,277 @@ class HookSoftFailed:
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
def _execute_hook_impl(
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
db_session: Session, # noqa: ARG001
|
||||
hook_point: HookPoint, # noqa: ARG001
|
||||
payload: dict[str, Any], # noqa: ARG001
|
||||
response_type: type[T], # noqa: ARG001
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""CE no-op — hooks are not available without EE."""
|
||||
return HookSkipped()
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _execute_hook_inner(
|
||||
hook: Hook,
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSoftFailed:
|
||||
"""Make the HTTP call, validate the response, and return a typed model.
|
||||
|
||||
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
|
||||
"""
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(
|
||||
timeout=timeout, follow_redirects=False
|
||||
) as client: # SSRF guard: never follow redirects
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
|
||||
# Validate the response payload against response_type.
|
||||
# A validation failure downgrades the outcome to a failure so it is logged,
|
||||
# is_reachable is left unchanged (server responded — just a bad payload),
|
||||
# and fail_strategy is respected below.
|
||||
validated_model: T | None = None
|
||||
if outcome.is_success and outcome.response_payload is not None:
|
||||
try:
|
||||
validated_model = response_type.model_validate(outcome.response_payload)
|
||||
except ValidationError as e:
|
||||
msg = (
|
||||
f"Hook response failed validation against {response_type.__name__}: {e}"
|
||||
)
|
||||
outcome = _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=outcome.status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
|
||||
if validated_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
f"validated_model is None for successful hook call (hook_id={hook_id})",
|
||||
)
|
||||
return validated_model
|
||||
|
||||
|
||||
def execute_hook(
|
||||
@@ -48,15 +367,25 @@ def execute_hook(
|
||||
payload: dict[str, Any],
|
||||
response_type: type[T],
|
||||
) -> T | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point.
|
||||
"""Execute the hook for the given hook point synchronously.
|
||||
|
||||
Dispatches to the versioned implementation so EE gets the real executor
|
||||
and CE gets the no-op stub, without any changes at the call site.
|
||||
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
|
||||
hook failed with SOFT fail strategy, or a validated response model on success.
|
||||
Raises OnyxError on HARD failure or if the hook is misconfigured.
|
||||
"""
|
||||
impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl")
|
||||
return impl(
|
||||
db_session=db_session,
|
||||
hook_point=hook_point,
|
||||
payload=payload,
|
||||
response_type=response_type,
|
||||
)
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
fail_strategy = hook.fail_strategy
|
||||
hook_id = hook.id
|
||||
|
||||
try:
|
||||
return _execute_hook_inner(hook, payload, response_type)
|
||||
except Exception:
|
||||
if fail_strategy == HookFailStrategy.SOFT:
|
||||
logger.exception(
|
||||
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
raise
|
||||
|
||||
@@ -1,114 +1,33 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class DocumentIngestionSection(BaseModel):
|
||||
"""Represents a single section of a document — either text or image, not both.
|
||||
|
||||
Text section: set `text`, leave `image_file_id` null.
|
||||
Image section: set `image_file_id`, leave `text` null.
|
||||
"""
|
||||
|
||||
text: str | None = Field(
|
||||
default=None,
|
||||
description="Text content of this section. Set for text sections, null for image sections.",
|
||||
)
|
||||
link: str | None = Field(
|
||||
default=None,
|
||||
description="Optional URL associated with this section. Preserve the original link from the payload if you want it retained.",
|
||||
)
|
||||
image_file_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Opaque identifier for an image stored in the file store. "
|
||||
"The image content is not included — this field signals that the section is an image. "
|
||||
"Hooks can use its presence to reorder or drop image sections, but cannot read or modify the image itself."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DocumentIngestionOwner(BaseModel):
|
||||
display_name: str | None = Field(
|
||||
default=None,
|
||||
description="Human-readable name of the owner.",
|
||||
)
|
||||
email: str | None = Field(
|
||||
default=None,
|
||||
description="Email address of the owner.",
|
||||
)
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
document_id: str = Field(
|
||||
description="Unique identifier for the document. Read-only — changes are ignored."
|
||||
)
|
||||
title: str | None = Field(description="Title of the document.")
|
||||
semantic_identifier: str = Field(
|
||||
description="Human-readable identifier used for display (e.g. file name, page title)."
|
||||
)
|
||||
source: str = Field(
|
||||
description=(
|
||||
"Connector source type (e.g. confluence, slack, google_drive). "
|
||||
"Read-only — changes are ignored. "
|
||||
"Full list of values: https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/configs/constants.py#L195"
|
||||
)
|
||||
)
|
||||
sections: list[DocumentIngestionSection] = Field(
|
||||
description="Sections of the document. Includes both text sections (text set, image_file_id null) and image sections (image_file_id set, text null)."
|
||||
)
|
||||
metadata: dict[str, list[str]] = Field(
|
||||
description="Key-value metadata attached to the document. Values are always a list of strings."
|
||||
)
|
||||
doc_updated_at: str | None = Field(
|
||||
description="ISO 8601 UTC timestamp of the last update at the source, or null if unknown. Example: '2024-03-15T10:30:00+00:00'."
|
||||
)
|
||||
primary_owners: list[DocumentIngestionOwner] | None = Field(
|
||||
description="Primary owners of the document, or null if not available."
|
||||
)
|
||||
secondary_owners: list[DocumentIngestionOwner] | None = Field(
|
||||
description="Secondary owners of the document, or null if not available."
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
sections: list[DocumentIngestionSection] | None = Field(
|
||||
description="The sections to index, in the desired order. Reorder, drop, or modify sections freely. Null or empty list drops the document."
|
||||
)
|
||||
rejection_reason: str | None = Field(
|
||||
default=None,
|
||||
description="Logged when sections is null or empty. Falls back to a generic message if omitted.",
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs on every document before it enters the indexing pipeline.
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
Call site: immediately after Onyx's internal validation and before the
|
||||
indexing pipeline begins — no partial writes have occurred yet.
|
||||
|
||||
If a Document Ingestion hook is configured, it takes precedence —
|
||||
Document Ingestion Light will not run. Configure only one per deployment.
|
||||
|
||||
Supported use cases:
|
||||
- Document filtering: drop documents based on content or metadata
|
||||
- Content rewriting: redact PII or normalize text before indexing
|
||||
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.DOCUMENT_INGESTION
|
||||
display_name = "Document Ingestion"
|
||||
description = (
|
||||
"Runs on every document before it enters the indexing pipeline. "
|
||||
"Allows filtering, rewriting, or dropping documents."
|
||||
)
|
||||
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
docs_url = "https://docs.onyx.app/admins/advanced_configs/hook_extensions#document-ingestion"
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
|
||||
@@ -65,9 +65,8 @@ class QueryProcessingSpec(HookPointSpec):
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
docs_url = (
|
||||
"https://docs.onyx.app/admins/advanced_configs/hook_extensions#query-processing"
|
||||
)
|
||||
# TODO(Bo-Onyx): update later
|
||||
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
|
||||
5
backend/onyx/hooks/utils.py
Normal file
5
backend/onyx/hooks/utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -33,7 +33,6 @@ from onyx.connectors.models import TextSection
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -48,13 +47,6 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.hooks.executor import execute_hook
|
||||
from onyx.hooks.executor import HookSkipped
|
||||
from onyx.hooks.executor import HookSoftFailed
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionOwner
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionPayload
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionResponse
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionSection
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
@@ -305,7 +297,6 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
request_id: str | None,
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
adapter: IndexingBatchAdapter,
|
||||
ignore_time_skip: bool = False,
|
||||
enable_contextual_rag: bool = False,
|
||||
@@ -319,7 +310,6 @@ def index_doc_batch_with_handler(
|
||||
document_batch=document_batch,
|
||||
request_id=request_id,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
adapter=adapter,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
@@ -795,132 +785,6 @@ def _verify_indexing_completeness(
|
||||
)
|
||||
|
||||
|
||||
def _apply_document_ingestion_hook(
|
||||
documents: list[Document],
|
||||
db_session: Session,
|
||||
) -> list[Document]:
|
||||
"""Apply the Document Ingestion hook to each document in the batch.
|
||||
|
||||
- HookSkipped / HookSoftFailed → document passes through unchanged.
|
||||
- Response with sections=None → document is dropped (logged).
|
||||
- Response with sections → document sections are replaced with the hook's output.
|
||||
"""
|
||||
|
||||
def _build_payload(doc: Document) -> DocumentIngestionPayload:
|
||||
return DocumentIngestionPayload(
|
||||
document_id=doc.id or "",
|
||||
title=doc.title,
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
source=doc.source.value if doc.source is not None else "",
|
||||
sections=[
|
||||
DocumentIngestionSection(
|
||||
text=s.text if isinstance(s, TextSection) else None,
|
||||
link=s.link,
|
||||
image_file_id=(
|
||||
s.image_file_id if isinstance(s, ImageSection) else None
|
||||
),
|
||||
)
|
||||
for s in doc.sections
|
||||
],
|
||||
metadata={
|
||||
k: v if isinstance(v, list) else [v] for k, v in doc.metadata.items()
|
||||
},
|
||||
doc_updated_at=(
|
||||
doc.doc_updated_at.isoformat() if doc.doc_updated_at else None
|
||||
),
|
||||
primary_owners=(
|
||||
[
|
||||
DocumentIngestionOwner(
|
||||
display_name=o.get_semantic_name() or None,
|
||||
email=o.email,
|
||||
)
|
||||
for o in doc.primary_owners
|
||||
]
|
||||
if doc.primary_owners
|
||||
else None
|
||||
),
|
||||
secondary_owners=(
|
||||
[
|
||||
DocumentIngestionOwner(
|
||||
display_name=o.get_semantic_name() or None,
|
||||
email=o.email,
|
||||
)
|
||||
for o in doc.secondary_owners
|
||||
]
|
||||
if doc.secondary_owners
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def _apply_result(
|
||||
doc: Document,
|
||||
hook_result: DocumentIngestionResponse | HookSkipped | HookSoftFailed,
|
||||
) -> Document | None:
|
||||
"""Return the modified doc, original doc (skip/soft-fail), or None (drop)."""
|
||||
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
|
||||
return doc
|
||||
if not hook_result.sections:
|
||||
reason = hook_result.rejection_reason or "Document rejected by hook"
|
||||
logger.info(
|
||||
f"Document ingestion hook dropped document doc_id={doc.id!r}: {reason}"
|
||||
)
|
||||
return None
|
||||
new_sections: list[TextSection | ImageSection] = []
|
||||
for s in hook_result.sections:
|
||||
if s.image_file_id is not None:
|
||||
new_sections.append(
|
||||
ImageSection(image_file_id=s.image_file_id, link=s.link)
|
||||
)
|
||||
elif s.text is not None:
|
||||
new_sections.append(TextSection(text=s.text, link=s.link))
|
||||
else:
|
||||
logger.warning(
|
||||
f"Document ingestion hook returned a section with neither text nor "
|
||||
f"image_file_id for doc_id={doc.id!r} — skipping section."
|
||||
)
|
||||
if not new_sections:
|
||||
logger.info(
|
||||
f"Document ingestion hook produced no valid sections for doc_id={doc.id!r} — dropping document."
|
||||
)
|
||||
return None
|
||||
return doc.model_copy(update={"sections": new_sections})
|
||||
|
||||
if not documents:
|
||||
return documents
|
||||
|
||||
# Run the hook for the first document. If it returns HookSkipped the hook
|
||||
# is not configured — skip the remaining N-1 DB lookups.
|
||||
first_doc = documents[0]
|
||||
first_payload = _build_payload(first_doc).model_dump()
|
||||
first_hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.DOCUMENT_INGESTION,
|
||||
payload=first_payload,
|
||||
response_type=DocumentIngestionResponse,
|
||||
)
|
||||
if isinstance(first_hook_result, HookSkipped):
|
||||
return documents
|
||||
|
||||
result: list[Document] = []
|
||||
first_applied = _apply_result(first_doc, first_hook_result)
|
||||
if first_applied is not None:
|
||||
result.append(first_applied)
|
||||
|
||||
for doc in documents[1:]:
|
||||
payload = _build_payload(doc).model_dump()
|
||||
hook_result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.DOCUMENT_INGESTION,
|
||||
payload=payload,
|
||||
response_type=DocumentIngestionResponse,
|
||||
)
|
||||
applied = _apply_result(doc, hook_result)
|
||||
if applied is not None:
|
||||
result.append(applied)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -930,7 +794,6 @@ def index_doc_batch(
|
||||
document_indices: list[DocumentIndex],
|
||||
request_id: str | None,
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
adapter: IndexingBatchAdapter,
|
||||
enable_contextual_rag: bool = False,
|
||||
llm: LLM | None = None,
|
||||
@@ -955,7 +818,6 @@ def index_doc_batch(
|
||||
)
|
||||
|
||||
filtered_documents = filter_fnc(document_batch)
|
||||
filtered_documents = _apply_document_ingestion_hook(filtered_documents, db_session)
|
||||
context = adapter.prepare(filtered_documents, ignore_time_skip)
|
||||
if not context:
|
||||
return IndexingPipelineResult.empty(len(filtered_documents))
|
||||
@@ -1143,7 +1005,6 @@ def run_indexing_pipeline(
|
||||
document_batch=document_batch,
|
||||
request_id=request_id,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
adapter=adapter,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
|
||||
@@ -175,28 +175,6 @@ def _strip_tool_content_from_messages(
|
||||
return result
|
||||
|
||||
|
||||
def _fix_tool_user_message_ordering(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Insert a synthetic assistant message between tool and user messages.
|
||||
|
||||
Some models (e.g. Mistral on Azure) require strict message ordering where
|
||||
a user message cannot immediately follow a tool message. This function
|
||||
inserts a minimal assistant message to bridge the gap.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
result: list[dict[str, Any]] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev_role = result[-1].get("role")
|
||||
curr_role = msg.get("role")
|
||||
if prev_role == "tool" and curr_role == "user":
|
||||
result.append({"role": "assistant", "content": "Noted. Continuing."})
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
@@ -598,18 +576,6 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Some models (e.g. Mistral) reject a user message
|
||||
# immediately after a tool message. Insert a synthetic
|
||||
# assistant bridge message to satisfy the ordering
|
||||
# constraint. Check both the provider and the deployment/
|
||||
# model name to catch Mistral hosted on Azure.
|
||||
model_or_deployment = (
|
||||
self._deployment_name or self._model_version or ""
|
||||
).lower()
|
||||
is_mistral_model = is_mistral or "mistral" in model_or_deployment
|
||||
if is_mistral_model:
|
||||
messages = _fix_tool_user_message_ordering(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
|
||||
@@ -8,24 +8,6 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class LLMOverride(BaseModel):
|
||||
"""Per-request LLM settings that override persona defaults.
|
||||
|
||||
All fields are optional — only the fields that differ from the persona's
|
||||
configured LLM need to be supplied. Used both over the wire (API requests)
|
||||
and for multi-model comparison, where one override is supplied per model.
|
||||
|
||||
Attributes:
|
||||
model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``).
|
||||
When ``None``, the persona's default provider is used.
|
||||
model_version: Specific model version string (e.g. ``"gpt-4o"``).
|
||||
When ``None``, the persona's default model is used.
|
||||
temperature: Sampling temperature in ``[0, 2]``. When ``None``, the
|
||||
persona's default temperature is used.
|
||||
display_name: Human-readable label shown in the UI for this model,
|
||||
e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version``
|
||||
when not set.
|
||||
"""
|
||||
|
||||
model_provider: str | None = None
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
@@ -77,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -454,6 +455,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -3,10 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
@@ -247,7 +247,7 @@ def handle_message(
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.account_type == AccountType.BOT
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
|
||||
@@ -76,18 +76,11 @@ class CategorizedFiles(BaseModel):
|
||||
acceptable: list[UploadFile] = Field(default_factory=list)
|
||||
rejected: list[RejectedFile] = Field(default_factory=list)
|
||||
acceptable_file_to_token_count: dict[str, int] = Field(default_factory=dict)
|
||||
# Filenames within `acceptable` that should be stored but not indexed.
|
||||
skip_indexing: set[str] = Field(default_factory=set)
|
||||
|
||||
# Allow FastAPI UploadFile instances
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
def _skip_token_threshold(extension: str) -> bool:
|
||||
"""Return True if this file extension should bypass the token limit."""
|
||||
return extension.lower() in OnyxFileExtensions.TABULAR_EXTENSIONS
|
||||
|
||||
|
||||
def _apply_long_side_cap(width: int, height: int, cap: int) -> tuple[int, int]:
|
||||
if max(width, height) <= cap:
|
||||
return width, height
|
||||
@@ -271,17 +264,7 @@ def categorize_uploaded_files(
|
||||
token_count = count_tokens(
|
||||
text_content, tokenizer, token_limit=token_threshold
|
||||
)
|
||||
exceeds_threshold = (
|
||||
token_threshold is not None and token_count > token_threshold
|
||||
)
|
||||
if exceeds_threshold and _skip_token_threshold(extension):
|
||||
# Exempt extensions (e.g. spreadsheets) are accepted
|
||||
# but flagged to skip indexing — only metadata is
|
||||
# injected into the LLM context.
|
||||
results.acceptable.append(upload)
|
||||
results.acceptable_file_to_token_count[filename] = token_count
|
||||
results.skip_indexing.add(filename)
|
||||
elif exceeds_threshold:
|
||||
if token_threshold is not None and token_count > token_threshold:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
|
||||
@@ -147,7 +147,6 @@ class UserInfo(BaseModel):
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
|
||||
memories: list[MemoryItem] | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -192,7 +191,10 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
enable_memory_tool=user.enable_memory_tool,
|
||||
memories=memories or [],
|
||||
memories=[
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in (user.memories or [])
|
||||
],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -27,7 +27,6 @@ from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -51,7 +50,6 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_live_users_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
@@ -59,7 +57,6 @@ from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.user_preferences import deactivate_user
|
||||
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
|
||||
from onyx.db.user_preferences import get_latest_access_token_for_user
|
||||
from onyx.db.user_preferences import get_memories_for_user
|
||||
from onyx.db.user_preferences import update_assistant_preferences
|
||||
from onyx.db.user_preferences import update_user_assistant_visibility
|
||||
from onyx.db.user_preferences import update_user_auto_scroll
|
||||
@@ -144,7 +141,6 @@ def set_user_role(
|
||||
validate_user_role_update(
|
||||
requested_role=requested_role,
|
||||
current_role=current_role,
|
||||
current_account_type=user_to_update.account_type,
|
||||
explicit_override=user_role_update_request.explicit_override,
|
||||
)
|
||||
|
||||
@@ -330,8 +326,8 @@ def list_all_users(
|
||||
if (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
slack_users = [user for user in users if user.account_type == AccountType.BOT]
|
||||
accepted_users = [user for user in users if user.account_type != AccountType.BOT]
|
||||
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
|
||||
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
|
||||
|
||||
accepted_emails = {user.email for user in accepted_users}
|
||||
slack_users_emails = {user.email for user in slack_users}
|
||||
@@ -674,7 +670,7 @@ def list_all_users_basic_info(
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.account_type != AccountType.BOT
|
||||
if user.role != UserRole.SLACK_USER
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
@@ -777,13 +773,6 @@ def _get_token_created_at(
|
||||
return get_current_token_creation_postgres(user, db_session)
|
||||
|
||||
|
||||
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
|
||||
def get_current_user_permissions(
|
||||
user: User = Depends(current_user),
|
||||
) -> list[str]:
|
||||
return sorted(p.value for p in get_effective_permissions(user))
|
||||
|
||||
|
||||
@router.get("/me", tags=PUBLIC_API_TAGS)
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
@@ -834,11 +823,6 @@ def verify_user_logged_in(
|
||||
[],
|
||||
),
|
||||
)
|
||||
memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
@@ -849,7 +833,6 @@ def verify_user_logged_in(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
return user_info
|
||||
@@ -947,8 +930,7 @@ def update_user_personalization_api(
|
||||
else user.enable_memory_tool
|
||||
)
|
||||
existing_memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
|
||||
]
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
|
||||
@@ -7,7 +7,6 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -42,7 +41,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
account_type: AccountType
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
@@ -62,7 +60,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
account_type=user.account_type,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -62,8 +60,6 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -575,46 +570,6 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -705,30 +660,6 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -9,8 +9,8 @@ def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
|
||||
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
|
||||
return ChatFileType.IMAGE
|
||||
|
||||
if mime_type in OnyxMimeTypes.TABULAR_MIME_TYPES:
|
||||
return ChatFileType.TABULAR
|
||||
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
|
||||
return ChatFileType.CSV
|
||||
|
||||
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
|
||||
return ChatFileType.DOC
|
||||
|
||||
@@ -2,25 +2,11 @@ from pydantic import BaseModel
|
||||
|
||||
|
||||
class Placement(BaseModel):
|
||||
"""Coordinates that identify where a streaming packet belongs in the UI.
|
||||
|
||||
The frontend uses these fields to route each packet to the correct turn,
|
||||
tool tab, agent sub-turn, and (in multi-model mode) response column.
|
||||
|
||||
Attributes:
|
||||
turn_index: Monotonically increasing index of the iterative reasoning block
|
||||
(e.g. tool call round) within this chat message. Lower values happened first.
|
||||
tab_index: Disambiguates parallel tool calls within the same turn so each
|
||||
tool's output can be displayed in its own tab.
|
||||
sub_turn_index: Nesting level for tools that invoke other tools. ``None`` for
|
||||
top-level packets; an integer for tool-within-tool output.
|
||||
model_index: Which model this packet belongs to. ``0`` for single-model
|
||||
responses; ``0``, ``1``, or ``2`` for multi-model comparison. ``None``
|
||||
for pre-LLM setup packets (e.g. message ID info) that are yielded
|
||||
before any Emitter runs.
|
||||
"""
|
||||
|
||||
# Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
|
||||
turn_index: int
|
||||
# For parallel tool calls to preserve order of execution
|
||||
tab_index: int = 0
|
||||
# Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
|
||||
sub_turn_index: int | None = None
|
||||
# For multi-model streaming: identifies which model (0, 1, 2) this packet belongs to.
|
||||
model_index: int | None = None
|
||||
|
||||
@@ -70,7 +70,7 @@ async def upsert_saml_user(email: str) -> User:
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.db.notification import get_notifications
|
||||
from onyx.db.notification import update_notification_last_shown
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
@@ -37,7 +38,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -98,7 +98,7 @@ def fetch_settings(
|
||||
needs_reindexing=needs_reindexing,
|
||||
onyx_craft_enabled=onyx_craft_enabled_for_user,
|
||||
vector_db_enabled=not DISABLE_VECTOR_DB,
|
||||
hooks_enabled=not MULTI_TENANT,
|
||||
hooks_enabled=HOOKS_AVAILABLE,
|
||||
version=onyx_version,
|
||||
max_allowed_upload_size_mb=MAX_ALLOWED_UPLOAD_SIZE_MB,
|
||||
default_user_file_max_upload_size_mb=min(
|
||||
|
||||
@@ -116,7 +116,7 @@ class UserSettings(Settings):
|
||||
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
|
||||
# document sets are unavailable.
|
||||
vector_db_enabled: bool = True
|
||||
# True when hooks are available: single-tenant EE deployments only.
|
||||
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
hooks_enabled: bool = False
|
||||
# Application version, read from the ONYX_VERSION env var at startup.
|
||||
version: str | None = None
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -709,6 +708,7 @@ def run_research_agent_calls(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from queue import Queue
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
@@ -744,8 +744,8 @@ if __name__ == "__main__":
|
||||
if user is None:
|
||||
raise ValueError("No users found in database. Please create a user first.")
|
||||
|
||||
emitter_queue: queue.Queue = queue.Queue()
|
||||
emitter = Emitter(merged_queue=emitter_queue)
|
||||
bus: Queue[Packet] = Queue()
|
||||
emitter = Emitter(bus)
|
||||
state_container = ChatStateContainer()
|
||||
|
||||
tool_dict = construct_tools(
|
||||
@@ -792,4 +792,4 @@ if __name__ == "__main__":
|
||||
print(result.intermediate_report)
|
||||
print("=" * 80)
|
||||
print(f"Citations: {result.citation_mapping}")
|
||||
print(f"Total packets emitted: {emitter_queue.qsize()}")
|
||||
print(f"Total packets emitted: {bus.qsize()}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import csv
|
||||
import json
|
||||
import queue
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
@@ -12,6 +11,7 @@ import requests
|
||||
from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -296,9 +296,9 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
|
||||
# Use a discard emitter if none provided (packets go nowhere)
|
||||
# Use default emitter if none provided
|
||||
if emitter is None:
|
||||
emitter = Emitter(merged_queue=queue.Queue())
|
||||
emitter = get_default_emitter()
|
||||
|
||||
return [
|
||||
CustomTool(
|
||||
@@ -367,7 +367,7 @@ if __name__ == "__main__":
|
||||
tools = build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=openapi_schema,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
dynamic_schema_info=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import io
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -10,7 +9,6 @@ from typing_extensions import override
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import load_chat_file_by_id
|
||||
@@ -171,13 +169,10 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
|
||||
|
||||
chat_file = self._load_file(file_id)
|
||||
|
||||
# Only PLAIN_TEXT and TABULAR are guaranteed to contain actual text bytes.
|
||||
# Only PLAIN_TEXT and CSV are guaranteed to contain actual text bytes.
|
||||
# DOC type in a loaded file means plaintext extraction failed and the
|
||||
# content is the original binary (e.g. raw PDF/DOCX bytes).
|
||||
if chat_file.file_type not in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.TABULAR,
|
||||
):
|
||||
if chat_file.file_type not in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV):
|
||||
raise ToolCallException(
|
||||
message=f"File {file_id} is not a text file (type={chat_file.file_type})",
|
||||
llm_facing_message=(
|
||||
@@ -186,19 +181,7 @@ class FileReaderTool(Tool[FileReaderToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
try:
|
||||
if chat_file.file_type == ChatFileType.PLAIN_TEXT:
|
||||
full_text = chat_file.content.decode("utf-8", errors="replace")
|
||||
else:
|
||||
full_text = (
|
||||
extract_file_text(
|
||||
file=io.BytesIO(chat_file.content),
|
||||
file_name=chat_file.filename or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
or ""
|
||||
)
|
||||
except ToolCallException:
|
||||
raise
|
||||
full_text = chat_file.content.decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
raise ToolCallException(
|
||||
message=f"Failed to decode file {file_id}",
|
||||
|
||||
@@ -458,27 +458,6 @@ def run_async_sync_no_cancel(coro: Awaitable[T]) -> T:
|
||||
return future.result()
|
||||
|
||||
|
||||
def run_multiple_in_background(
|
||||
funcs: list[Callable[[], None]],
|
||||
thread_name_prefix: str = "worker",
|
||||
) -> ThreadPoolExecutor:
|
||||
"""Submit multiple callables to a ``ThreadPoolExecutor`` with context propagation.
|
||||
|
||||
Copies the current ``contextvars`` context once and runs every callable
|
||||
inside that copy, which is important for preserving tenant IDs and other
|
||||
context-local state across threads.
|
||||
|
||||
Returns the executor so the caller can ``shutdown()`` when done.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=len(funcs), thread_name_prefix=thread_name_prefix
|
||||
)
|
||||
for func in funcs:
|
||||
executor.submit(ctx.run, func)
|
||||
return executor
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread, Generic[R]):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
|
||||
@@ -14,7 +14,7 @@ aiofiles==25.1.0
|
||||
# unstructured-client
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.4
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
@@ -271,7 +271,7 @@ fastapi-users-db-sqlalchemy==7.0.0
|
||||
# via onyx
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
fastmcp==3.2.0
|
||||
fastmcp==3.0.2
|
||||
# via onyx
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
@@ -1102,8 +1102,6 @@ tzdata==2025.2
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via dateparser
|
||||
uncalled-for==0.2.0
|
||||
# via fastmcp
|
||||
unstructured==0.18.27
|
||||
# via onyx
|
||||
unstructured-client==0.42.6
|
||||
|
||||
@@ -10,7 +10,7 @@ aiofiles==25.1.0
|
||||
# via aioboto3
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.4
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
|
||||
@@ -10,7 +10,7 @@ aiofiles==25.1.0
|
||||
# via aioboto3
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.4
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
|
||||
@@ -12,7 +12,7 @@ aiofiles==25.1.0
|
||||
# via aioboto3
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.4
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
|
||||
@@ -5,7 +5,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -28,9 +27,6 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
|
||||
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
|
||||
MAX_REQUEST_ATTEMPTS = 5
|
||||
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
QUESTION_TIMEOUT_SECONDS = 300
|
||||
QUESTION_RETRY_PAUSE_SECONDS = 30
|
||||
MAX_QUESTION_ATTEMPTS = 3
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -113,27 +109,6 @@ def normalize_api_base(api_base: str) -> str:
|
||||
return f"{normalized}/api"
|
||||
|
||||
|
||||
def load_completed_question_ids(output_file: Path) -> set[str]:
|
||||
if not output_file.exists():
|
||||
return set()
|
||||
|
||||
completed_ids: set[str] = set()
|
||||
with output_file.open("r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
try:
|
||||
record = json.loads(stripped)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
question_id = record.get("question_id")
|
||||
if isinstance(question_id, str) and question_id:
|
||||
completed_ids.add(question_id)
|
||||
|
||||
return completed_ids
|
||||
|
||||
|
||||
def load_questions(questions_file: Path) -> list[QuestionRecord]:
|
||||
if not questions_file.exists():
|
||||
raise FileNotFoundError(f"Questions file not found: {questions_file}")
|
||||
@@ -373,7 +348,6 @@ async def generate_answers(
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
parallelism: int,
|
||||
skipped: int,
|
||||
) -> None:
|
||||
if parallelism < 1:
|
||||
raise ValueError("`--parallelism` must be at least 1.")
|
||||
@@ -408,178 +382,58 @@ async def generate_answers(
|
||||
write_lock = asyncio.Lock()
|
||||
completed = 0
|
||||
successful = 0
|
||||
stuck_count = 0
|
||||
failed_questions: list[FailedQuestionRecord] = []
|
||||
remaining_count = len(questions)
|
||||
overall_total = remaining_count + skipped
|
||||
question_durations: list[float] = []
|
||||
run_start_time = time.monotonic()
|
||||
|
||||
def print_progress() -> None:
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
elapsed = time.monotonic() - run_start_time
|
||||
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
|
||||
|
||||
done = skipped + completed
|
||||
bar_width = 30
|
||||
filled = (
|
||||
int(bar_width * done / overall_total)
|
||||
if overall_total
|
||||
else bar_width
|
||||
)
|
||||
bar = "█" * filled + "░" * (bar_width - filled)
|
||||
pct = (done / overall_total * 100) if overall_total else 100.0
|
||||
|
||||
parts = (
|
||||
f"\r{bar} {pct:5.1f}% "
|
||||
f"[{done}/{overall_total}] "
|
||||
f"avg {avg_time:.1f}s/q "
|
||||
f"elapsed {elapsed:.0f}s "
|
||||
f"ETA {eta:.0f}s "
|
||||
f"(ok:{successful} fail:{len(failed_questions)}"
|
||||
)
|
||||
if stuck_count:
|
||||
parts += f" stuck:{stuck_count}"
|
||||
if skipped:
|
||||
parts += f" skip:{skipped}"
|
||||
parts += ")"
|
||||
|
||||
sys.stderr.write(parts)
|
||||
sys.stderr.flush()
|
||||
|
||||
print_progress()
|
||||
total = len(questions)
|
||||
|
||||
async def process_question(question_record: QuestionRecord) -> None:
|
||||
nonlocal completed
|
||||
nonlocal successful
|
||||
nonlocal stuck_count
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
|
||||
q_start = time.monotonic()
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await asyncio.wait_for(
|
||||
submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
),
|
||||
timeout=QUESTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
async with progress_lock:
|
||||
stuck_count += 1
|
||||
logger.warning(
|
||||
"Question %s timed out after %ss (attempt %s/%s, "
|
||||
"total stuck: %s) — retrying in %ss",
|
||||
question_record.question_id,
|
||||
QUESTION_TIMEOUT_SECONDS,
|
||||
attempt,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
stuck_count,
|
||||
QUESTION_RETRY_PAUSE_SECONDS,
|
||||
)
|
||||
print_progress()
|
||||
last_error = TimeoutError(
|
||||
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
|
||||
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
)
|
||||
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
|
||||
continue
|
||||
except Exception as exc:
|
||||
duration = time.monotonic() - q_start
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
question_durations.append(duration)
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
return
|
||||
|
||||
duration = time.monotonic() - q_start
|
||||
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
except Exception as exc:
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
successful += 1
|
||||
question_durations.append(duration)
|
||||
print_progress()
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
total,
|
||||
)
|
||||
return
|
||||
|
||||
# All attempts exhausted due to timeouts
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(last_error),
|
||||
)
|
||||
)
|
||||
logger.error(
|
||||
"Question %s failed after %s timeout attempts (%s/%s)",
|
||||
question_record.question_id,
|
||||
MAX_QUESTION_ATTEMPTS,
|
||||
completed,
|
||||
remaining_count,
|
||||
)
|
||||
print_progress()
|
||||
successful += 1
|
||||
logger.info("Processed %s/%s questions", completed, total)
|
||||
|
||||
await asyncio.gather(
|
||||
*(process_question(question_record) for question_record in questions)
|
||||
)
|
||||
|
||||
# Final newline after progress bar
|
||||
sys.stderr.write("\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
total_elapsed = time.monotonic() - run_start_time
|
||||
avg_time = (
|
||||
sum(question_durations) / len(question_durations)
|
||||
if question_durations
|
||||
else 0.0
|
||||
)
|
||||
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
|
||||
resume_suffix = (
|
||||
f" — {skipped} previously completed, "
|
||||
f"{skipped + successful}/{overall_total} overall"
|
||||
if skipped
|
||||
else ""
|
||||
)
|
||||
logger.info(
|
||||
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
|
||||
successful,
|
||||
remaining_count,
|
||||
total_elapsed,
|
||||
avg_time,
|
||||
stuck_suffix,
|
||||
resume_suffix,
|
||||
)
|
||||
|
||||
if failed_questions:
|
||||
logger.warning(
|
||||
"%s questions failed:",
|
||||
"Completed with %s failed questions and %s successful questions.",
|
||||
len(failed_questions),
|
||||
successful,
|
||||
)
|
||||
for failed_question in failed_questions:
|
||||
logger.warning(
|
||||
@@ -599,30 +453,7 @@ def main() -> None:
|
||||
raise ValueError("`--max-questions` must be at least 1 when provided.")
|
||||
questions = questions[: args.max_questions]
|
||||
|
||||
completed_ids = load_completed_question_ids(args.output_file)
|
||||
logger.info(
|
||||
"Found %s already-answered question IDs in %s",
|
||||
len(completed_ids),
|
||||
args.output_file,
|
||||
)
|
||||
total_before_filter = len(questions)
|
||||
questions = [q for q in questions if q.question_id not in completed_ids]
|
||||
skipped = total_before_filter - len(questions)
|
||||
|
||||
if skipped:
|
||||
logger.info(
|
||||
"Resuming: %s/%s already answered, %s remaining",
|
||||
skipped,
|
||||
total_before_filter,
|
||||
len(questions),
|
||||
)
|
||||
else:
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
|
||||
if not questions:
|
||||
logger.info("All questions already answered. Nothing to do.")
|
||||
return
|
||||
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
logger.info("Writing answers to %s", args.output_file)
|
||||
|
||||
asyncio.run(
|
||||
@@ -632,7 +463,6 @@ def main() -> None:
|
||||
api_base=api_base,
|
||||
api_key=args.api_key,
|
||||
parallelism=args.parallelism,
|
||||
skipped=skipped,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.file_retrieval import DriveFileFieldType
|
||||
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import RetrievedDriveFile
|
||||
@@ -73,8 +75,10 @@ def test_connector_skips_link_only_files_when_enabled() -> None:
|
||||
retrieved_file = _build_retrieved_file(
|
||||
[{"type": "domain", "allowFileDiscovery": False}]
|
||||
)
|
||||
fetch_mock = MagicMock(return_value=iter([retrieved_file]))
|
||||
|
||||
with (
|
||||
patch.object(connector, "_fetch_drive_items", fetch_mock),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.run_functions_tuples_in_parallel",
|
||||
side_effect=_stub_run_functions,
|
||||
@@ -89,16 +93,21 @@ def test_connector_skips_link_only_files_when_enabled() -> None:
|
||||
convert_mock.return_value = "doc"
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
results = list(
|
||||
connector._convert_retrieved_files_to_documents(
|
||||
drive_files_iter=iter([retrieved_file]),
|
||||
connector._extract_docs_from_google_drive(
|
||||
checkpoint=checkpoint,
|
||||
start=None,
|
||||
end=None,
|
||||
include_permissions=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert results == []
|
||||
convert_mock.assert_not_called()
|
||||
fetch_mock.assert_called_once()
|
||||
get_new_ancestors_mock.assert_called_once()
|
||||
assert (
|
||||
fetch_mock.call_args.kwargs["field_type"] == DriveFileFieldType.WITH_PERMISSIONS
|
||||
)
|
||||
|
||||
|
||||
def test_connector_processes_files_when_option_disabled() -> None:
|
||||
@@ -106,8 +115,10 @@ def test_connector_processes_files_when_option_disabled() -> None:
|
||||
retrieved_file = _build_retrieved_file(
|
||||
[{"type": "domain", "allowFileDiscovery": False}]
|
||||
)
|
||||
fetch_mock = MagicMock(return_value=iter([retrieved_file]))
|
||||
|
||||
with (
|
||||
patch.object(connector, "_fetch_drive_items", fetch_mock),
|
||||
patch(
|
||||
"onyx.connectors.google_drive.connector.run_functions_tuples_in_parallel",
|
||||
side_effect=_stub_run_functions,
|
||||
@@ -122,13 +133,16 @@ def test_connector_processes_files_when_option_disabled() -> None:
|
||||
convert_mock.return_value = "doc"
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
results = list(
|
||||
connector._convert_retrieved_files_to_documents(
|
||||
drive_files_iter=iter([retrieved_file]),
|
||||
connector._extract_docs_from_google_drive(
|
||||
checkpoint=checkpoint,
|
||||
start=None,
|
||||
end=None,
|
||||
include_permissions=False,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
convert_mock.assert_called_once()
|
||||
fetch_mock.assert_called_once()
|
||||
get_new_ancestors_mock.assert_called_once()
|
||||
assert fetch_mock.call_args.kwargs["field_type"] == DriveFileFieldType.STANDARD
|
||||
|
||||
@@ -27,13 +27,11 @@ def create_placement(
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
sub_turn_index: int | None = None,
|
||||
model_index: int | None = 0,
|
||||
) -> Placement:
|
||||
return Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
model_index=model_index,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -53,12 +52,7 @@ def tenant_context() -> Generator[None, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def create_test_user(
|
||||
db_session: Session,
|
||||
email_prefix: str,
|
||||
role: UserRole = UserRole.BASIC,
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
) -> User:
|
||||
def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
"""Helper to create a test user with a unique email"""
|
||||
# Use UUID to ensure unique email addresses
|
||||
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
|
||||
@@ -74,8 +68,7 @@ def create_test_user(
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
account_type=account_type,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,29 +13,16 @@ from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import PublicExternalUserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.models import UserRole
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_ext_perm_user(db_session: Session, name: str) -> User:
|
||||
"""Create an external-permission user for group sync tests."""
|
||||
return create_test_user(
|
||||
db_session,
|
||||
name,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_connector_credential_pair(
|
||||
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
|
||||
) -> ConnectorCredentialPair:
|
||||
@@ -113,9 +100,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_initial_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing external groups for the first time (initial sync)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Mock external groups data as a generator that yields the expected groups
|
||||
@@ -188,9 +175,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_update_existing_groups(self, db_session: Session) -> None:
|
||||
"""Test updating existing groups (adding/removing users)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with original groups
|
||||
@@ -285,8 +272,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_remove_groups(self, db_session: Session) -> None:
|
||||
"""Test removing groups (groups that no longer exist in external system)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with multiple groups
|
||||
@@ -370,7 +357,7 @@ class TestPerformExternalGroupSync:
|
||||
def test_empty_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing when no groups are returned (all groups removed)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with groups
|
||||
@@ -426,7 +413,7 @@ class TestPerformExternalGroupSync:
|
||||
# Create many test users
|
||||
users = []
|
||||
for i in range(150): # More than the batch size of 100
|
||||
users.append(_create_ext_perm_user(db_session, f"user{i}"))
|
||||
users.append(create_test_user(db_session, f"user{i}"))
|
||||
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
@@ -465,8 +452,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
|
||||
"""Test syncing a mix of regular and public groups"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
def mixed_group_sync_func(
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -53,7 +52,6 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Tests that account_type is correctly set when creating users through
|
||||
the internal DB functions: add_slack_user_if_not_exists and
|
||||
batch_add_ext_perm_user_if_not_exists.
|
||||
|
||||
These functions are called by background workers (Slack bot, permission sync)
|
||||
and are not exposed via API endpoints, so they must be tested directly.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
|
||||
|
||||
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
|
||||
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
|
||||
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
|
||||
|
||||
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
|
||||
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
|
||||
users = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session, ["extperm_acct_type@test.com"]
|
||||
)
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
assert user.role == UserRole.EXT_PERM_USER
|
||||
assert user.account_type == AccountType.EXT_PERM_USER
|
||||
|
||||
|
||||
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
|
||||
email = "ext_to_slack_acct_type@test.com"
|
||||
|
||||
# Create as ext_perm user first
|
||||
batch_add_ext_perm_user_if_not_exists(db_session, [email])
|
||||
|
||||
# Now "upgrade" via slack path
|
||||
user = add_slack_user_if_not_exists(db_session, email)
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -47,7 +46,6 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,7 +13,6 @@ This test:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -21,7 +20,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -138,7 +137,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -201,7 +200,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -276,7 +275,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -351,7 +350,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -459,7 +458,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
@@ -542,7 +541,7 @@ class TestMCPPassThroughOAuth:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(),
|
||||
|
||||
@@ -8,7 +8,6 @@ Tests the priority logic for OAuth tokens when constructing custom tools:
|
||||
All external HTTP calls are mocked, but Postgres and Redis are running.
|
||||
"""
|
||||
|
||||
import queue
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
@@ -17,7 +16,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -175,7 +174,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
search_tool_config=search_tool_config,
|
||||
@@ -233,7 +232,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -285,7 +284,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -346,7 +345,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -417,7 +416,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -484,7 +483,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
@@ -537,7 +536,7 @@ class TestOAuthToolIntegrationPriority:
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
emitter=Emitter(merged_queue=queue.Queue()),
|
||||
emitter=get_default_emitter(),
|
||||
user=user,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user