Compare commits

...

3 Commits

Author SHA1 Message Date
Wenxi
ed40cbdd00 chore: hotfix/v2.0.0 beta.3 (#5715)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
Co-authored-by: Edwin Luo <edwin@parafin.com>
2025-10-14 12:29:51 -07:00
Wenxi
b36910240d chore: Hotfix v2.0.0-beta.2 (#5658)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: trial-danswer <trial@danswer.ai>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Nils <94993442+nsklei@users.noreply.github.com>
Co-authored-by: nsklei <nils.kleinrahm@pledoc.de>
Co-authored-by: Shahar Mazor <103638798+Django149@users.noreply.github.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Paulius Klyvis <grafke@users.noreply.github.com>
2025-10-07 18:30:48 -07:00
Wenxi
488b27ba04 chore: hotfix v2.0.0 beta.1 (#5616)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
2025-10-07 17:08:17 -07:00
643 changed files with 26436 additions and 47842 deletions

View File

@@ -8,9 +8,9 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
jobs:
build-and-push:
@@ -33,7 +33,16 @@ jobs:
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout code
uses: actions/checkout@v4
@@ -46,7 +55,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -119,7 +129,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3

View File

@@ -11,8 +11,8 @@ env:
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
jobs:
@@ -145,6 +145,15 @@ jobs:
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
@@ -157,11 +166,16 @@ jobs:
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@v3

View File

@@ -7,7 +7,10 @@ on:
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
DEPLOYMENT: standalone
jobs:
@@ -45,6 +48,15 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout
uses: actions/checkout@v4
@@ -57,7 +69,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -126,7 +139,8 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3

View File

@@ -25,9 +25,11 @@ jobs:
- name: Add required Helm repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add keda https://kedacore.github.io/charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Build chart dependencies

View File

@@ -20,6 +20,7 @@ env:
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -65,35 +65,45 @@ jobs:
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Pre-pull critical images
- name: Install Redis operator
if: steps.list-changed.outputs.changed == 'true'
shell: bash
run: |
echo "=== Installing redis-operator CRDs ==="
helm upgrade --install redis-operator ot-container-kit/redis-operator \
--namespace redis-operator --create-namespace --wait --timeout 300s
- name: Pre-pull required images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling critical images to avoid timeout ==="
# Get kind cluster name
echo "=== Pre-pulling required images to avoid timeout ==="
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
# Pre-pull images that are likely to be used
echo "Pre-pulling PostgreSQL image..."
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
echo "Pre-pulling Redis image..."
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
echo "Pre-pulling Onyx images..."
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
IMAGES=(
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
"quay.io/opstree/redis:v7.0.15"
"docker.io/onyxdotapp/onyx-web-server:latest"
)
for image in "${IMAGES[@]}"; do
echo "Pre-pulling $image"
if docker pull "$image"; then
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
else
echo "Failed to pull $image"
fi
done
echo "=== Images loaded into Kind cluster ==="
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
@@ -149,6 +159,7 @@ jobs:
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
set +e
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
@@ -156,8 +167,10 @@ jobs:
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.primary.persistence.enabled=false \
--set=postgresql.nameOverride=cloudnative-pg \
--set=postgresql.cluster.storage.storageClass=standard \
--set=redis.enabled=true \
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
@@ -173,8 +186,16 @@ jobs:
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
echo "=== Installation completed successfully ==="
CT_EXIT=$?
set -e
if [[ $CT_EXIT -ne 0 ]]; then
echo "ct install failed with exit code $CT_EXIT"
exit $CT_EXIT
else
echo "=== Installation completed successfully ==="
fi
kubectl get pods --all-namespaces
- name: Post-install verification
@@ -199,7 +220,7 @@ jobs:
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so

View File

@@ -22,9 +22,11 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -131,6 +133,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -158,6 +161,7 @@ jobs:
push: true
outputs: type=registry
provenance: false
no-cache: true
build-integration-image:
needs: prepare-build
@@ -191,6 +195,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
integration-tests:
needs:
@@ -337,9 +342,11 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -19,9 +19,11 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -128,6 +130,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
build-model-server-image:
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
@@ -155,6 +158,7 @@ jobs:
push: true
outputs: type=registry
provenance: false
no-cache: true
build-integration-image:
needs: prepare-build
@@ -188,6 +192,7 @@ jobs:
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
push: true
outputs: type=registry
no-cache: true
integration-tests-mit:
needs:
@@ -334,9 +339,11 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -20,11 +20,13 @@ env:
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# Jira
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
# Gong
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
@@ -132,7 +134,24 @@ jobs:
playwright install chromium
playwright install-deps chromium
- name: Run Tests
- name: Detect Connector changes
id: changes
uses: dorny/paths-filter@v3
with:
filters: |
hubspot:
- 'backend/onyx/connectors/hubspot/**'
- 'backend/tests/daily/connectors/hubspot/**'
salesforce:
- 'backend/onyx/connectors/salesforce/**'
- 'backend/tests/daily/connectors/salesforce/**'
github:
- 'backend/onyx/connectors/github/**'
- 'backend/tests/daily/connectors/github/**'
file_processing:
- 'backend/onyx/file_processing/**'
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
@@ -142,7 +161,49 @@ jobs:
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors
backend/tests/daily/connectors \
--ignore backend/tests/daily/connectors/hubspot \
--ignore backend/tests/daily/connectors/salesforce \
--ignore backend/tests/daily/connectors/github
- name: Run HubSpot Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors/hubspot
- name: Run Salesforce Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.salesforce == 'true' || steps.changes.outputs.file_processing == 'true' }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors/salesforce
- name: Run GitHub Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.github == 'true' || steps.changes.outputs.file_processing == 'true' }}
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \
-n 8 \
--dist loadfile \
--durations=8 \
-o junit_family=xunit2 \
-xv \
--ff \
backend/tests/daily/connectors/github
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'

View File

@@ -34,8 +34,7 @@ repos:
hooks:
- id: prettier
types_or: [html, css, javascript, ts, tsx]
additional_dependencies:
- prettier
language_version: system
- repo: local
hooks:

View File

@@ -23,12 +23,10 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery background",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
"Celery beat"
],
"presentation": {
"group": "1"
@@ -42,16 +40,32 @@
}
},
{
"name": "Celery (all)",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery light",
"Celery background",
"Celery docfetching",
"Celery docprocessing",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery kg_processing",
"Celery monitoring",
"Celery user_file_processing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
"Celery beat"
],
"presentation": {
"group": "1"
@@ -199,6 +213,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -221,13 +264,100 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery kg_processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.kg_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=kg_processing@%n",
"-Q",
"kg_processing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery kg_processing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user_file_processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console"
},
{
"name": "Celery docfetching",
"type": "debugpy",
@@ -311,58 +441,6 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user file processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"--pool=threads",
"-Q",
"user_file_processing,user_file_project_sync"
],
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user file processing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",

View File

@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
- Processes user-uploaded files
- Handles user file indexing and project synchronization
- Configurable concurrency
9. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
@@ -82,6 +87,31 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (single worker process)
- Suitable for smaller deployments or development environments
- Default concurrency: 6 threads
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability

View File

@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
- Processes user-uploaded files
- Handles user file indexing and project synchronization
- Configurable concurrency
9. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
@@ -82,11 +87,36 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (single worker process)
- Suitable for smaller deployments or development environments
- Default concurrency: 6 threads
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
via Celery Beat.
- **Task Prioritization**: High, Medium, Low priority queues
- **Monitoring**: Built-in heartbeat and liveness checking

View File

@@ -13,8 +13,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
@@ -28,8 +27,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
### Contributing Code
@@ -46,9 +44,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
and
[Discord](https://discord.gg/TDJ59cGV2X).
[Discord](https://discord.gg/4NA5SbzrWb).
We would love to see you there!
@@ -105,6 +101,11 @@ pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```
Fix vscode/cursor auto-imports:
```bash
pip install -e .
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:
@@ -117,8 +118,15 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap
#### Frontend: Node dependencies
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `onyx/web` run:
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
to manage your Node installations. Once installed, you can run
```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```
Navigate to `onyx/web` and run:
```bash
npm i
@@ -129,8 +137,6 @@ npm i
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).
With the virtual environment active, install the pre-commit library with:
@@ -150,15 +156,17 @@ To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend`
### Web
We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
Re-stage your changes and commit again.
# Running the application for development
## Developing using VSCode Debugger (recommended)
We highly recommend using VSCode debugger for development.
**We highly recommend using VSCode debugger for development.**
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.

View File

@@ -21,6 +21,9 @@ Before starting, make sure the Docker Daemon is running.
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
Note: Clear and Restart External Volumes and Containers will reset your postgres and Vespa (relational-db and index).
Only run this if you are okay with wiping your data.
## Features
- Hot reload is enabled for the web server and API servers

View File

@@ -111,6 +111,8 @@ COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
# Put logo in assets
COPY ./assets /app/assets

View File

@@ -0,0 +1,28 @@
"""reset userfile document_id_migrated field
Revision ID: 40926a4dab77
Revises: 64bd5677aeb6
Create Date: 2025-10-06 16:10:32.898668
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "40926a4dab77"
down_revision = "64bd5677aeb6"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Reset ``document_id_migrated`` on every ``user_file`` row.

    Sets the flag to FALSE wherever it is not already FALSE (``IS DISTINCT
    FROM`` also matches NULL) so all rows are marked as not migrated.
    """
    reset_sql = (
        "UPDATE user_file SET document_id_migrated = FALSE "
        "WHERE document_id_migrated IS DISTINCT FROM FALSE;"
    )
    op.execute(reset_sql)
def downgrade() -> None:
    """Nothing to undo: the upgrade only rewrote row data, no schema change."""
    # No-op
    pass

View File

@@ -0,0 +1,37 @@
"""Add image input support to model config
Revision ID: 64bd5677aeb6
Revises: b30353be4eec
Create Date: 2025-09-28 15:48:12.003612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "64bd5677aeb6"
down_revision = "b30353be4eec"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add ``supports_image_input`` to model_configuration and normalize
    any NULL ``is_visible`` values to FALSE."""
    image_input_col = sa.Column("supports_image_input", sa.Boolean(), nullable=True)
    op.add_column("model_configuration", image_input_col)

    # Seems to be left over from when model visibility was introduced and a
    # nullable field; set any NULL is_visible values to False.
    backfill = sa.text(
        "UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
    )
    op.get_bind().execute(backfill)
def downgrade() -> None:
    """Drop the column added by upgrade(); the is_visible backfill is not reverted."""
    op.drop_column("model_configuration", "supports_image_input")

View File

@@ -0,0 +1,45 @@
"""mcp_tool_enabled
Revision ID: 96a5702df6aa
Revises: 40926a4dab77
Create Date: 2025-10-09 12:10:21.733097
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "96a5702df6aa"
down_revision = "40926a4dab77"
branch_labels = None
depends_on = None
DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false"
def upgrade() -> None:
    """Add a NOT NULL ``enabled`` flag to ``tool`` plus a supporting index.

    Existing rows are backfilled to TRUE through a temporary server default,
    which is then dropped so the application controls defaulting.
    """
    enabled_col = sa.Column(
        "enabled",
        sa.Boolean(),
        nullable=False,
        server_default=sa.true(),  # backfills existing rows as enabled
    )
    op.add_column("tool", enabled_col)
    op.create_index(
        "ix_tool_mcp_server_enabled", "tool", ["mcp_server_id", "enabled"]
    )
    # Remove the server default so application controls defaulting.
    op.alter_column("tool", "enabled", server_default=None)
def downgrade() -> None:
    """Remove the ``enabled`` flag and its index.

    Disabled tools are deleted first so dropping the column does not
    silently re-enable them.
    """
    op.execute(DELETE_DISABLED_TOOLS_SQL)
    op.drop_index("ix_tool_mcp_server_enabled", table_name="tool")
    op.drop_column("tool", "enabled")

View File

@@ -1,8 +1,13 @@
import json
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import Any
from typing import cast
import jwt
import requests
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
@@ -10,6 +15,7 @@ from fastapi import status
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -30,43 +36,156 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_PUBLIC_KEY_FETCH_ATTEMPTS = 2
class PublicKeyFormat(Enum):
    """Format of the JWT verification material served by JWT_PUBLIC_KEY_URL."""

    JWKS = "jwks"  # JSON document containing a "keys" array of JWKs
    PEM = "pem"  # raw PEM-encoded public key text
@lru_cache()
def get_public_key() -> str | None:
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
"""Fetch and cache the raw JWT verification material."""
if JWT_PUBLIC_KEY_URL is None:
logger.error("JWT_PUBLIC_KEY_URL is not set")
return None
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
return response.text
try:
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
except requests.RequestException as exc:
logger.error(f"Failed to fetch JWT public key: {str(exc)}")
return None
content_type = response.headers.get("Content-Type", "").lower()
raw_body = response.text
body_lstripped = raw_body.lstrip()
if "application/json" in content_type or body_lstripped.startswith("{"):
try:
data = response.json()
except ValueError:
logger.error("JWT public key URL returned invalid JSON")
return None
if isinstance(data, dict) and "keys" in data:
return data, PublicKeyFormat.JWKS
logger.error(
"JWT public key URL returned JSON but no JWKS 'keys' field was found"
)
return None
body = raw_body.strip()
if not body:
logger.error("JWT public key URL returned an empty response")
return None
return body, PublicKeyFormat.PEM
def get_public_key(token: str) -> RSAPublicKey | str | None:
    """Return the concrete public key used to verify the provided JWT token.

    Args:
        token: The raw JWT; when the endpoint serves a JWKS document, the
            token's (unverified) header is used to select the matching key.

    Returns:
        An ``RSAPublicKey`` resolved from a JWKS document, the PEM string
        itself when the endpoint serves raw PEM, or ``None`` when the key
        material could not be fetched or no key matched.
    """
    payload = _fetch_public_key_payload()
    if payload is None:
        logger.error("Failed to retrieve public key payload")
        return None

    key_material, key_format = payload
    if key_format is PublicKeyFormat.JWKS:
        # JWKS: pick the JWK that matches this token's header.
        jwks_data = cast(dict[str, Any], key_material)
        return _resolve_public_key_from_jwks(token, jwks_data)

    # PEM: return the raw key text unchanged for jwt_decode to consume.
    return cast(str, key_material)
def _resolve_public_key_from_jwks(
    token: str, jwks_payload: dict[str, Any]
) -> RSAPublicKey | None:
    """Select the JWK matching *token*'s header and build an RSA public key.

    Matching is attempted by ``kid``, then by ``x5t`` thumbprint, then —
    as a last resort — by using the only key present. Returns ``None`` if
    the header cannot be parsed, no key matches, or key construction fails.
    """
    try:
        header = jwt.get_unverified_header(token)
    except PyJWTError as e:
        logger.error(f"Unable to parse JWT header: {str(e)}")
        return None

    keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
    if not keys:
        logger.error("JWKS payload did not contain any keys")
        return None

    kid = header.get("kid")
    thumbprint = header.get("x5t")

    # Match by key id first, then certificate thumbprint, then fall back to
    # a sole key when the JWKS contains exactly one entry.
    candidates = []
    if kid:
        candidates = [k for k in keys if k.get("kid") == kid]
    if not candidates and thumbprint:
        candidates = [k for k in keys if k.get("x5t") == thumbprint]
    if not candidates and len(keys) == 1:
        candidates = keys

    if not candidates:
        logger.warning(
            "No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
        )
        return None

    if len(candidates) > 1:
        logger.warning(
            "Multiple JWKs matched token header kid=%s; selecting the first occurrence",
            kid,
        )

    jwk = candidates[0]
    try:
        # RSAAlgorithm.from_jwk expects the JWK as a JSON string.
        return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
    except ValueError as e:
        logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
        return None
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
try:
public_key_pem = get_public_key()
if public_key_pem is None:
logger.error("Failed to retrieve public key")
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
public_key = get_public_key(token)
if public_key is None:
logger.error("Unable to resolve a public key for JWT verification")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
try:
payload = jwt_decode(
token,
public_key,
algorithms=["RS256"],
options={"verify_aud": False},
)
except InvalidTokenError as e:
logger.error(f"Invalid JWT token: {str(e)}")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
payload = jwt_decode(
token,
public_key_pem,
algorithms=["RS256"],
audience=None,
)
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
except InvalidTokenError:
logger.error("Invalid JWT token")
get_public_key.cache_clear()
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
get_public_key.cache_clear()
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
break
return None

View File

@@ -0,0 +1,12 @@
from onyx.background.celery.apps.background import celery_app

# Register the EE-only task modules with the shared background celery app so
# their task definitions are discovered at worker startup.
celery_app.autodiscover_tasks(
    [
        "ee.onyx.background.celery.tasks.doc_permission_syncing",
        "ee.onyx.background.celery.tasks.external_group_syncing",
        "ee.onyx.background.celery.tasks.cleanup",
        "ee.onyx.background.celery.tasks.tenant_provisioning",
        "ee.onyx.background.celery.tasks.query_history",
    ]
)

View File

@@ -1,123 +1,4 @@
import csv
import io
from datetime import datetime
from celery import shared_task
from celery import Task
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.celery.apps.heavy import celery_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
trail=False,
)
def export_query_history_task(
self: Task,
*,
start: datetime,
end: datetime,
start_time: datetime,
# Need to include the tenant_id since the TenantAwareTask needs this
tenant_id: str,
) -> None:
if not self.request.id:
raise RuntimeError("No task id defined for this task; cannot identify it")
task_id = self.request.id
stream = io.StringIO()
writer = csv.DictWriter(
stream,
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
)
writer.writeheader()
with get_session_with_current_tenant() as db_session:
try:
mark_task_as_started_with_id(
db_session=db_session,
task_id=task_id,
)
snapshot_generator = fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
)
for snapshot in snapshot_generator:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
writer.writerows(
qa_pair.to_json()
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
snapshot
)
)
except Exception:
logger.exception(f"Failed to export query history with {task_id=}")
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
report_name = construct_query_history_report_name(task_id)
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,
file_type=FileType.CSV,
file_metadata={
"start": start.isoformat(),
"end": end.isoformat(),
"start_time": start_time.isoformat(),
},
file_id=report_name,
)
delete_task_with_id(
db_session=db_session,
task_id=task_id,
)
except Exception:
logger.exception(
f"Failed to save query history export file; {report_name=}"
)
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise
celery_app.autodiscover_tasks(
@@ -125,5 +6,6 @@ celery_app.autodiscover_tasks(
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
)

View File

@@ -0,0 +1,119 @@
import csv
import io
from datetime import datetime
from celery import shared_task
from celery import Task
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task,
    *,
    start: datetime,
    end: datetime,
    start_time: datetime,
    # Need to include the tenant_id since the TenantAwareTask needs this
    tenant_id: str,
) -> None:
    """Export chat-session query history in [start, end] to a CSV report.

    Streams snapshots into an in-memory CSV buffer, then persists it to the
    default file store under a name derived from the celery task id. On any
    failure the tracking task row is marked as failed and the exception is
    re-raised so celery records the error.
    """
    if not self.request.id:
        raise RuntimeError("No task id defined for this task; cannot identify it")
    task_id = self.request.id

    stream = io.StringIO()
    writer = csv.DictWriter(
        stream,
        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
    )
    writer.writeheader()

    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(
                db_session=db_session,
                task_id=task_id,
            )

            snapshot_generator = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )

            for snapshot in snapshot_generator:
                # Scrub user emails when query history is configured as anonymized.
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    snapshot.user_email = ONYX_ANONYMIZED_EMAIL

                writer.writerows(
                    qa_pair.to_json()
                    for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                        snapshot
                    )
                )
        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    report_name = construct_query_history_report_name(task_id)
    with get_session_with_current_tenant() as db_session:
        try:
            # Rewind the buffer so save_file writes the CSV from the start.
            stream.seek(0)
            get_default_file_store().save_file(
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata={
                    "start": start.isoformat(),
                    "end": end.isoformat(),
                    "start_time": start_time.isoformat(),
                },
                file_id=report_name,
            )

            delete_task_with_id(
                db_session=db_session,
                task_id=task_id,
            )
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

View File

@@ -124,9 +124,9 @@ def get_space_permission(
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"No permissions found for space '{space_key}'. This is very unlikely "
"to be correct and is more likely caused by an access token with "
"insufficient permissions. Make sure that the access token has Admin "
f"permissions for space '{space_key}'"
)

View File

@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
else 0.0
)
return gmail_connector.retrieve_all_slim_documents(
return gmail_connector.retrieve_all_slim_docs_perm_sync(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
else 0.0
)
return google_drive_connector.retrieve_all_slim_documents(
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
for raw_perm in permissions:
if not hasattr(raw_perm, "raw"):
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw)
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
# In order to associate this permission to some Atlassian entity, we need the "Holder".
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
if not permission.holder:
logger.warn(
logger.warning(
f"Expected to find a permission holder, but none was found: {permission=}"
)
continue
type = permission.holder.get("type")
if not type:
logger.warn(
logger.warning(
f"Expected to find the type of permission holder, but none was found: {permission=}"
)
continue

View File

@@ -105,7 +105,9 @@ def _get_slack_document_access(
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:

View File

@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ def generic_doc_sync(
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
slim_connector: SlimConnectorWithPermSync,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -40,7 +40,7 @@ def generic_doc_sync(
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -0,0 +1,15 @@
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
from onyx.feature_flags.interface import FeatureFlagProvider
def get_posthog_feature_flag_provider() -> FeatureFlagProvider:
    """Build the EE feature-flag provider backed by PostHog.

    The versioned implementation loader calls this factory to obtain the
    enterprise-edition provider.

    Returns:
        A fresh ``PostHogFeatureFlagProvider`` instance.
    """
    provider = PostHogFeatureFlagProvider()
    return provider

View File

@@ -0,0 +1,54 @@
from typing import Any
from uuid import UUID
from ee.onyx.utils.posthog_client import posthog
from onyx.feature_flags.interface import FeatureFlagProvider
from onyx.utils.logger import setup_logger
logger = setup_logger()
class PostHogFeatureFlagProvider(FeatureFlagProvider):
    """Feature-flag provider backed by PostHog.

    Evaluates flags through PostHog's feature flag API for a specific user.
    Only active in multi-tenant mode.
    """

    def feature_enabled(
        self,
        flag_key: str,
        user_id: UUID,
        user_properties: dict[str, Any] | None = None,
    ) -> bool:
        """Evaluate a PostHog feature flag for one user.

        Args:
            flag_key: The identifier of the feature flag to check.
            user_id: The unique identifier of the user.
            user_properties: Optional user attributes that may influence
                flag evaluation.

        Returns:
            True when PostHog reports the flag enabled; False when it is
            disabled, unknown, or any error occurs during evaluation.
        """
        try:
            # Record/refresh the person's properties before evaluating.
            posthog.set(
                distinct_id=user_id,
                properties=user_properties,
            )
            result = posthog.feature_enabled(
                flag_key,
                str(user_id),
                person_properties=user_properties,
            )
        except Exception as e:
            logger.error(
                f"Error checking feature flag {flag_key} for user {user_id}: {e}"
            )
            return False
        # PostHog may return None for unknown flags; treat that as disabled.
        return False if result is None else bool(result)

View File

@@ -1,45 +0,0 @@
import json
import os
from typing import cast
from typing import List
from cohere import Client
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
Embedding = List[float]
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
    """Load the seed documents, optionally embedding them with Cohere.

    Args:
        cohere_enabled: When True (and a default Cohere API key is
            configured), load the Cohere-specific seed file and attach
            fresh title/content embeddings to each document.

    Returns:
        The list of seed document dicts; on the Cohere path each dict
        also gains "title_embedding" and "content_embedding" keys.
    """
    base_path = os.path.join(os.getcwd(), "onyx", "seeding")
    use_cohere = bool(cohere_enabled and COHERE_DEFAULT_API_KEY)

    file_name = "initial_docs_cohere.json" if use_cohere else "initial_docs.json"
    initial_docs_path = os.path.join(base_path, file_name)
    # Use a context manager so the file handle is closed deterministically;
    # the previous bare json.load(open(...)) leaked the handle.
    with open(initial_docs_path, encoding="utf-8") as f:
        processed_docs = json.load(f)

    if use_cohere:
        cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
        embed_model = "embed-english-v3.0"

        for doc in processed_docs:
            title_embed_response = cohere_client.embed(
                texts=[doc["title"]],
                model=embed_model,
                input_type="search_document",
            )
            content_embed_response = cohere_client.embed(
                texts=[doc["content"]],
                model=embed_model,
                input_type="search_document",
            )

            doc["title_embedding"] = cast(
                List[Embedding], title_embed_response.embeddings
            )[0]
            doc["content_embedding"] = cast(
                List[Embedding], content_embed_response.embeddings
            )[0]

    return processed_docs

View File

@@ -37,6 +37,19 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")
# Azure AD / Entra ID often returns the email attribute under different keys.
# Keep a list of common variations so we can fall back gracefully if the IdP
# does not send the plain "email" attribute name.
EMAIL_ATTRIBUTE_KEYS = {
"email",
"emailaddress",
"mail",
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/mail",
"http://schemas.microsoft.com/identity/claims/emailaddress",
}
EMAIL_ATTRIBUTE_KEYS_LOWER = {key.lower() for key in EMAIL_ATTRIBUTE_KEYS}
async def upsert_saml_user(email: str) -> User:
"""
@@ -204,16 +217,37 @@ async def _process_saml_callback(
detail=detail,
)
user_email = auth.get_attribute("email")
if not user_email:
detail = "SAML is not set up correctly, email attribute must be provided."
logger.error(detail)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
user_email: str | None = None
user_email = user_email[0]
# The OneLogin toolkit normalizes attribute keys, but still performs a
# case-sensitive lookup. Try the common keys first and then fall back to a
# case-insensitive scan of all returned attributes.
for attribute_key in EMAIL_ATTRIBUTE_KEYS:
attribute_values = auth.get_attribute(attribute_key)
if attribute_values:
user_email = attribute_values[0]
break
if not user_email:
# Fallback: perform a case-insensitive lookup across all attributes in
# case the IdP sent the email claim with a different capitalization.
attributes = auth.get_attributes()
for key, values in attributes.items():
if key.lower() in EMAIL_ATTRIBUTE_KEYS_LOWER:
if values:
user_email = values[0]
break
if not user_email:
detail = "SAML is not set up correctly, email attribute must be provided."
logger.error(detail)
logger.debug(
"Received SAML attributes without email: %s",
list(attributes.keys()),
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
user = await upsert_saml_user(email=user_email)

View File

@@ -37,9 +37,9 @@ from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import get_anthropic_model_names
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
@@ -278,7 +278,7 @@ def configure_default_api_keys(db_session: Session) -> None:
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
max_input_tokens=None,
)
for name in ANTHROPIC_MODEL_NAMES
for name in get_anthropic_model_names()
],
api_key_changed=True,
)

View File

@@ -0,0 +1,22 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
    """Log any PostHog delivery errors.

    Installed as the ``on_error`` callback of the module-level Posthog
    client; ``items`` is presumably the batch whose delivery failed.
    """
    logger.error(f"PostHog error: {error}, items: {items}")
# Shared module-level PostHog client configured from the EE app settings.
# NOTE(review): debug=True is hard-coded here — confirm this is intended
# outside of development environments.
posthog = Posthog(
    project_api_key=POSTHOG_API_KEY,
    host=POSTHOG_HOST,
    debug=True,
    on_error=posthog_on_error,
)

View File

@@ -1,27 +1,9 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from ee.onyx.utils.posthog_client import posthog
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
) -> None:

View File

@@ -1,14 +1,12 @@
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
@@ -37,23 +35,30 @@ from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
if TYPE_CHECKING:
from setfit import SetFitModel # type: ignore
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
logger = setup_logger()
router = APIRouter(prefix="/custom")
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
global _CONNECTOR_CLASSIFIER_TOKENIZER
from transformers import AutoTokenizer, PreTrainedTokenizer
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
# unmodified distilbert tokenizer.
@@ -95,7 +100,9 @@ def get_local_connector_classifier(
return _CONNECTOR_CLASSIFIER_MODEL
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
from transformers import AutoTokenizer, PreTrainedTokenizer
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
# The tokenizer details are not uploaded to the HF hub since it's just the
@@ -141,7 +148,9 @@ def get_local_intent_model(
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
) -> "SetFitModel":
from setfit import SetFitModel
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
@@ -179,7 +188,7 @@ def get_local_information_content_model(
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
tokenizer: PreTrainedTokenizer,
tokenizer: "PreTrainedTokenizer",
connector_token_end_id: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
@@ -267,7 +276,7 @@ def warm_up_information_content_model() -> None:
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
device = intent_model.device
@@ -401,7 +410,7 @@ def run_content_classification_inference(
def map_keywords(
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
) -> list[str]:
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore

View File

@@ -2,13 +2,11 @@ import asyncio
import time
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from litellm.exceptions import RateLimitError
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
@@ -20,6 +18,9 @@ from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder, SentenceTransformer
logger = setup_logger()
router = APIRouter(prefix="/encoder")
@@ -88,8 +89,10 @@ def get_embedding_model(
def get_local_reranking_model(
model_name: str,
) -> CrossEncoder:
) -> "CrossEncoder":
global _RERANK_MODEL
from sentence_transformers import CrossEncoder # type: ignore
if _RERANK_MODEL is None:
logger.notice(f"Loading {model_name}")
model = CrossEncoder(model_name)
@@ -207,6 +210,8 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
from litellm.exceptions import RateLimitError
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(

View File

@@ -30,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
from shared_configs.configs import SKIP_WARM_UP
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
@@ -91,16 +92,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
if not SKIP_WARM_UP:
if not INDEXING_ONLY:
logger.notice("Warming up intent model for inference model server")
warm_up_intent_model()
else:
logger.notice(
"Warming up content information model for indexing model server"
)
warm_up_information_content_model()
else:
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
yield

View File

@@ -1,16 +1,20 @@
import json
import os
from typing import cast
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from transformers import DistilBertConfig # type: ignore
from transformers import DistilBertModel # type: ignore
from transformers import DistilBertTokenizer # type: ignore
if TYPE_CHECKING:
from transformers import DistilBertConfig # type: ignore
class HybridClassifier(nn.Module):
def __init__(self) -> None:
from transformers import DistilBertConfig, DistilBertModel
super().__init__()
config = DistilBertConfig()
self.distilbert = DistilBertModel(config)
@@ -74,7 +78,9 @@ class HybridClassifier(nn.Module):
class ConnectorClassifier(nn.Module):
def __init__(self, config: DistilBertConfig) -> None:
def __init__(self, config: "DistilBertConfig") -> None:
from transformers import DistilBertTokenizer, DistilBertModel
super().__init__()
self.config = config
@@ -115,6 +121,8 @@ class ConnectorClassifier(nn.Module):
@classmethod
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
from transformers import DistilBertConfig
config = cast(
DistilBertConfig,
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from typing import Any
from typing import cast
from braintrust import traced
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
@@ -22,6 +23,9 @@ from onyx.agents.agent_search.dr.models import DecisionResponse
from onyx.agents.agent_search.dr.models import DRPromptPurpose
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
from onyx.agents.agent_search.dr.models import OrchestratorTool
from onyx.agents.agent_search.dr.process_llm_stream import (
BasicSearchProcessedStreamResults,
)
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.dr.states import MainState
from onyx.agents.agent_search.dr.states import OrchestrationSetup
@@ -70,6 +74,7 @@ from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.prompts.prompt_template import PromptTemplate
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -116,7 +121,9 @@ def _get_available_tools(
else:
include_kg = False
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
tool_dict: dict[int, Tool] = {
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
}
for tool in graph_config.tooling.tools:
@@ -484,6 +491,7 @@ def clarifier(
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ graph_config.inputs.project_instructions
)
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
chat_history_string = (
get_chat_history_string(
@@ -666,28 +674,30 @@ def clarifier(
system_prompt_to_use = assistant_system_prompt
user_prompt_to_use = decision_prompt + assistant_task_prompt
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
full_response = process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
@traced(name="clarifier stream and process", type="llm")
def stream_and_process() -> BasicSearchProcessedStreamResults:
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
tool_choice=(None),
structured_response_format=graph_config.inputs.structured_response_format,
)
return process_llm_stream(
messages=stream,
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
full_response = stream_and_process()
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):

View File

@@ -1,3 +1,4 @@
import json
from datetime import datetime
from typing import cast
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,6 +64,29 @@ def image_generation(
image_tool_info = state.available_tools[state.tools_used[-1]]
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
image_prompt = branch_query
requested_shape: ImageShape | None = None
try:
parsed_query = json.loads(branch_query)
except json.JSONDecodeError:
parsed_query = None
if isinstance(parsed_query, dict):
prompt_from_llm = parsed_query.get("prompt")
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
image_prompt = prompt_from_llm.strip()
raw_shape = parsed_query.get("shape")
if isinstance(raw_shape, str):
try:
requested_shape = ImageShape(raw_shape)
except ValueError:
logger.warning(
"Received unsupported image shape '%s' from LLM. Falling back to square.",
raw_shape,
)
logger.debug(
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
)
@@ -69,7 +94,15 @@ def image_generation(
# Generate images using the image generation tool
image_generation_responses: list[ImageGenerationResponse] = []
for tool_response in image_tool.run(prompt=branch_query):
if requested_shape is not None:
tool_iterator = image_tool.run(
prompt=image_prompt,
shape=requested_shape.value,
)
else:
tool_iterator = image_tool.run(prompt=image_prompt)
for tool_response in tool_iterator:
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
# Stream heartbeat to frontend
write_custom_event(
@@ -95,6 +128,7 @@ def image_generation(
file_id=file_id,
url=build_frontend_file_url(file_id),
revised_prompt=img.revised_prompt,
shape=(requested_shape or ImageShape.SQUARE).value,
)
for file_id, img in zip(file_ids, image_generation_responses)
]
@@ -107,15 +141,29 @@ def image_generation(
if final_generated_images:
image_descriptions = []
for i, img in enumerate(final_generated_images, 1):
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
if img.shape and img.shape != ImageShape.SQUARE.value:
image_descriptions.append(
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
)
else:
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
answer_string = (
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
+ "\n".join(image_descriptions)
)
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
if requested_shape:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
)
else:
reasoning = (
"Used image generation tool to create "
f"{len(final_generated_images)} image(s) based on the user's request."
)
else:
answer_string = f"Failed to generate images for request: {branch_query}"
answer_string = f"Failed to generate images for request: {image_prompt}"
reasoning = "Image generation tool did not return any results."
return BranchUpdate(

View File

@@ -5,6 +5,7 @@ class GeneratedImage(BaseModel):
file_id: str
url: str
revised_prompt: str
shape: str | None = None
# Needed for PydanticType

View File

@@ -2,30 +2,28 @@ from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# TODO Dependency inject for testing
class ExaClient(InternetSearchProvider):
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="fast",
livecrawl="never",
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
@@ -34,7 +32,7 @@ class ExaClient(InternetSearchProvider):
)
return [
InternetSearchResult(
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
@@ -49,7 +47,7 @@ class ExaClient(InternetSearchProvider):
]
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=urls,
text=True,
@@ -57,7 +55,7 @@ class ExaClient(InternetSearchProvider):
)
return [
InternetContent(
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",

View File

@@ -0,0 +1,147 @@
import json
from concurrent.futures import ThreadPoolExecutor
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
WebSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(WebSearchProvider):
    """Web search provider backed by the Serper.dev Google Search API.

    `search` hits the search endpoint for organic results; `contents`
    scrapes page bodies via Serper's scrape endpoint, in parallel.
    """

    def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
        # Serper authenticates via the X-API-KEY header on every request.
        self.headers = {
            "X-API-KEY": api_key,
            "Content-Type": "application/json",
        }

    @retry_builder(tries=3, delay=1, backoff=2)
    def search(self, query: str) -> list[WebSearchResult]:
        """Run a search query and return the organic results.

        Raises requests.HTTPError on non-2xx responses (after retries).
        """
        payload = {
            "q": query,
        }
        response = requests.post(
            SERPER_SEARCH_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )
        response.raise_for_status()
        results = response.json()
        # NOTE(review): assumes "organic" is always present on success — a
        # KeyError here would propagate after the retries are exhausted.
        organic_results = results["organic"]
        return [
            WebSearchResult(
                title=result["title"],
                link=result["link"],
                snippet=result["snippet"],
                # The search endpoint does not expose author/date info.
                author=None,
                published_date=None,
            )
            for result in organic_results
        ]

    def contents(self, urls: list[str]) -> list[WebContent]:
        """Scrape page contents for each URL, preserving input order."""
        if not urls:
            return []

        # Serper can respond with 500s regularly. We want to retry,
        # but in the event of failure, return an unsuccessful scrape.
        def safe_get_webpage_content(url: str) -> WebContent:
            try:
                return self._get_webpage_content(url)
            except Exception:
                # Placeholder result instead of dropping the URL, so the
                # output list stays aligned with the input list.
                return WebContent(
                    title="",
                    link=url,
                    full_content="",
                    published_date=None,
                    scrape_successful=False,
                )

        with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
            return list(e.map(safe_get_webpage_content, urls))

    @retry_builder(tries=3, delay=1, backoff=2)
    def _get_webpage_content(self, url: str) -> WebContent:
        """Scrape a single URL via the Serper scrape endpoint."""
        payload = {
            "url": url,
        }
        response = requests.post(
            SERPER_CONTENTS_URL,
            headers=self.headers,
            data=json.dumps(payload),
        )
        # 400 returned when serper cannot scrape; treat as a failed scrape
        # rather than an error worth retrying.
        if response.status_code == 400:
            return WebContent(
                title="",
                link=url,
                full_content="",
                published_date=None,
                scrape_successful=False,
            )
        response.raise_for_status()
        response_json = response.json()
        # Response only guarantees text
        text = response_json["text"]
        # metadata & jsonld is not guaranteed to be present
        metadata = response_json.get("metadata", {})
        jsonld = response_json.get("jsonld", {})
        title = extract_title_from_metadata(metadata)
        # Serper does not provide a reliable mechanism to extract the url
        response_url = url
        published_date_str = extract_published_date_from_jsonld(jsonld)
        published_date = None
        if published_date_str:
            try:
                published_date = time_str_to_utc(published_date_str)
            except Exception:
                # Unparseable date strings are treated as "no date".
                published_date = None
        return WebContent(
            title=title or "",
            link=response_url,
            full_content=text or "",
            published_date=published_date,
        )
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None

View File

@@ -7,7 +7,7 @@ from langsmith import traceable
from onyx.agents.agent_search.dr.models import WebSearchAnswer
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
@@ -75,15 +75,15 @@ def web_search(
raise ValueError("No internet search provider found")
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[InternetSearchResult]:
search_results: list[InternetSearchResult] = []
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = provider.search(search_query)
except Exception as e:
logger.error(f"Error performing search: {e}")
return search_results
search_results: list[InternetSearchResult] = _search(search_query)
search_results: list[WebSearchResult] = _search(search_query)
search_results_text = "\n\n".join(
[
f"{i}. {result.title}\n URL: {result.link}\n"

View File

@@ -4,7 +4,7 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
@@ -23,7 +23,7 @@ def dedup_urls(
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
unique_results_by_link: dict[str, InternetSearchResult] = {}
unique_results_by_link: dict[str, WebSearchResult] = {}
for query, result in state.results_to_open:
branch_questions_to_urls[query].append(result.link)
if result.link not in unique_results_by_link:

View File

@@ -13,7 +13,7 @@ class ProviderType(Enum):
EXA = "exa"
class InternetSearchResult(BaseModel):
class WebSearchResult(BaseModel):
title: str
link: str
author: str | None = None
@@ -21,18 +21,19 @@ class InternetSearchResult(BaseModel):
snippet: str | None = None
class InternetContent(BaseModel):
class WebContent(BaseModel):
title: str
link: str
full_content: str
published_date: datetime | None = None
scrape_successful: bool = True
class InternetSearchProvider(ABC):
class WebSearchProvider(ABC):
@abstractmethod
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
pass
@abstractmethod
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
pass

View File

@@ -1,13 +1,19 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> InternetSearchProvider | None:
def get_default_provider() -> WebSearchProvider | None:
if EXA_API_KEY:
return ExaClient()
if SERPER_API_KEY:
return SerperClient()
return None

View File

@@ -4,13 +4,13 @@ from typing import Annotated
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.context.search.models import InferenceSection
class InternetSearchInput(SubAgentInput):
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
parallelization_nr: int = 0
branch_question: Annotated[str, lambda x, y: y] = ""
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
@@ -18,7 +18,7 @@ class InternetSearchInput(SubAgentInput):
class InternetSearchUpdate(LoggerUpdate):
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
class FetchInput(SubAgentInput):

View File

@@ -1,8 +1,8 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
@@ -17,7 +17,7 @@ def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
def dummy_inference_section_from_internet_content(
result: InternetContent,
result: WebContent,
) -> InferenceSection:
truncated_content = truncate_search_result_content(result.full_content)
return InferenceSection(
@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
boost=1,
recency_bias=1.0,
score=1.0,
hidden=False,
hidden=(not result.scrape_successful),
metadata={},
match_highlights=[],
doc_summary=truncated_content,
@@ -48,7 +48,7 @@ def dummy_inference_section_from_internet_content(
def dummy_inference_section_from_internet_search_result(
result: InternetSearchResult,
result: WebSearchResult,
) -> InferenceSection:
return InferenceSection(
center_chunk=InferenceChunk(

View File

@@ -1,7 +1,6 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
@@ -88,11 +87,5 @@ class GraphConfig(BaseModel):
# Only needed for agentic search
persistence: GraphPersistence
@model_validator(mode="after")
def validate_search_tool(self) -> "GraphConfig":
if self.behavior.use_agentic_search and self.tooling.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True

View File

@@ -1,71 +0,0 @@
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ToolCallException(Exception):
    """Exception raised for errors during tool calls.

    Wraps any failure from a tool's response/result generation so callers
    can distinguish tool-execution errors from other graph failures.
    """
def call_tool(
    state: ToolChoiceUpdate,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolCallUpdate:
    """Execute the tool chosen in *state* and return its packaged output.

    Raises:
        ValueError: when no tool choice is present on the state.
        ToolCallException: when the tool itself fails while running.
    """
    cast(GraphConfig, config["metadata"]["config"])

    chosen = state.tool_choice
    if chosen is None:
        raise ValueError("Cannot invoke tool call node without a tool choice")

    runner = ToolRunner(
        chosen.tool, chosen.tool_args, override_kwargs=chosen.search_tool_override_kwargs
    )
    kickoff = runner.kickoff()

    try:
        # Drain all intermediate responses before asking for the final result.
        responses = list(runner.tool_responses())
        final_result = runner.tool_final_result()
    except Exception as e:
        raise ToolCallException(
            f"Error during tool call for {chosen.tool.display_name}: {e}"
        ) from e

    tool_call = ToolCall(name=chosen.tool.name, args=chosen.tool_args, id=chosen.id)
    summary = ToolCallSummary(
        tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
        tool_call_result=build_tool_message(tool_call, runner.tool_message_content()),
    )

    return ToolCallUpdate(
        tool_call_output=ToolCallOutput(
            tool_call_summary=summary,
            tool_call_kickoff=kickoff,
            tool_call_responses=responses,
            tool_call_final_result=final_result,
        )
    )

View File

@@ -1,354 +0,0 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.tools.models import QueryExpansions
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
    """Render the chat history as a plain-text transcript.

    Only human and AI messages are included; any other message type
    is silently skipped.
    """
    # TODO: Add trimming logic
    role_for_type = ((HumanMessage, "User"), (AIMessage, "Assistant"))
    segments: list[str] = []
    for message in prompt_builder.message_history:
        for msg_type, role in role_for_type:
            if isinstance(message, msg_type):
                segments.append(f"{role}:\n {message.content}\n\n")
                break
    return "\n".join(segments)
def _expand_query(
    query: str,
    expansion_type: QueryExpansionType,
    prompt_builder: AnswerPromptBuilder,
) -> str:
    """Ask the primary LLM to rewrite *query* for retrieval.

    Picks a prompt variant based on whether chat history exists and on
    whether a keyword or semantic expansion is requested, then returns
    the LLM's rewritten query verbatim.
    """
    history_str = _create_history_str(prompt_builder)
    wants_keyword = expansion_type == QueryExpansionType.KEYWORD

    if history_str:
        base_prompt = (
            QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
            if wants_keyword
            else QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
        )
        expansion_prompt = base_prompt.format(question=query, history=history_str)
    else:
        base_prompt = (
            QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
            if wants_keyword
            else QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
        )
        expansion_prompt = base_prompt.format(question=query)

    primary_llm, _ = get_default_llms()
    response = primary_llm.invoke([HumanMessage(content=expansion_prompt)])
    return cast(str, response.content)
def _expand_query_non_tool_calling_llm(
    expanded_keyword_thread: TimeoutThread[str],
    expanded_semantic_thread: TimeoutThread[str],
) -> QueryExpansions | None:
    """Join both expansion threads and bundle their results.

    Returns None when either thread produced no expansion.
    """
    keyword_expansion = wait_on_background(expanded_keyword_thread)
    semantic_expansion = wait_on_background(expanded_semantic_thread)
    if keyword_expansion is not None and semantic_expansion is not None:
        return QueryExpansions(
            keywords_expansions=[keyword_expansion],
            semantic_expansions=[semantic_expansion],
        )
    return None
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
    state: ToolChoiceState,
    config: RunnableConfig,
    writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
    """
    This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
    The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.

    Flow:
      1. Optionally kick off background precomputation (embedding, keyword
         analysis, query expansions) when basic search may be used.
      2. If a tool call is forced (or chosen via the non-tool-calling-LLM
         path), return that choice immediately.
      3. Otherwise stream the LLM and pick the first tool call it emits
         whose name matches a known tool; no tool call means no choice.
    """
    should_stream_answer = state.should_stream_answer

    agent_config = cast(GraphConfig, config["metadata"]["config"])

    force_use_tool = agent_config.tooling.force_use_tool

    # Background precomputation handles; populated only when basic (non-agentic)
    # search is possible for this request.
    embedding_thread: TimeoutThread[Embedding] | None = None
    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
    expanded_keyword_thread: TimeoutThread[str] | None = None
    expanded_semantic_thread: TimeoutThread[str] | None = None

    # If we have override_kwargs, add them to the tool_args
    override_kwargs: SearchToolOverrideKwargs = (
        force_use_tool.override_kwargs or SearchToolOverrideKwargs()
    )
    override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query

    using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
    prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

    llm = agent_config.tooling.primary_llm
    skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation

    # Precompute search inputs only when a basic search could actually run:
    # not agentic, a search tool exists, and any forced tool is the search tool.
    if (
        not agent_config.behavior.use_agentic_search
        and agent_config.tooling.search_tool is not None
        and (
            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
        )
    ):
        # Run in a background thread to avoid blocking the main thread
        embedding_thread = run_in_background(
            get_query_embedding,
            agent_config.inputs.prompt_builder.raw_user_query,
            agent_config.persistence.db_session,
        )
        keyword_thread = run_in_background(
            query_analysis,
            agent_config.inputs.prompt_builder.raw_user_query,
        )

        if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
            expanded_keyword_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.KEYWORD,
                prompt_builder,
            )
            expanded_semantic_thread = run_in_background(
                _expand_query,
                agent_config.inputs.prompt_builder.raw_user_query,
                QueryExpansionType.SEMANTIC,
                prompt_builder,
            )

    structured_response_format = agent_config.inputs.structured_response_format
    # Restrict to tools this node is allowed to use per the state.
    tools = [
        tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
    ]

    tool, tool_args = None, None
    if force_use_tool.force_use and force_use_tool.args is not None:
        tool_name, tool_args = (
            force_use_tool.tool_name,
            force_use_tool.args,
        )
        tool = get_tool_by_name(tools, tool_name)

    # special pre-logic for non-tool calling LLM case
    elif not using_tool_calling_llm and tools:
        chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=force_use_tool,
            tools=tools,
            prompt_builder=prompt_builder,
            llm=llm,
        )
        if chosen_tool_and_args:
            tool, tool_args = chosen_tool_and_args

    # If we have a tool and tool args, we are ready to request a tool call.
    # This only happens if the tool call was forced or we are using a non-tool calling LLM.
    if tool and tool_args:
        if embedding_thread and tool.name == SearchTool._NAME:
            # Wait for the embedding thread to finish
            embedding = wait_on_background(embedding_thread)
            override_kwargs.precomputed_query_embedding = embedding
        if keyword_thread and tool.name == SearchTool._NAME:
            is_keyword, keywords = wait_on_background(keyword_thread)
            override_kwargs.precomputed_is_keyword = is_keyword
            override_kwargs.precomputed_keywords = keywords
        # dual keyword expansion needs to be added here for non-tool calling LLM case
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and expanded_keyword_thread
            and expanded_semantic_thread
            and tool.name == SearchTool._NAME
        ):
            override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
                expanded_keyword_thread=expanded_keyword_thread,
                expanded_semantic_thread=expanded_semantic_thread,
            )
        # Sanity check: if expansions were expected, both sides must be present.
        if (
            USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
            and tool.name == SearchTool._NAME
            and override_kwargs.expanded_queries
        ):
            if (
                override_kwargs.expanded_queries.keywords_expansions is None
                or override_kwargs.expanded_queries.semantic_expansions is None
            ):
                raise ValueError("No expanded keyword or semantic threads found.")

        return ToolChoiceUpdate(
            tool_choice=ToolChoice(
                tool=tool,
                tool_args=tool_args,
                id=str(uuid4()),
                search_tool_override_kwargs=override_kwargs,
            ),
        )

    # if we're skipping gen ai answer generation, we should only
    # continue if we're forcing a tool call (which will be emitted by
    # the tool calling llm in the stream() below)
    if skip_gen_ai_answer_generation and not force_use_tool.force_use:
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    built_prompt = (
        prompt_builder.build()
        if isinstance(prompt_builder, AnswerPromptBuilder)
        else prompt_builder.built_prompt
    )
    # At this point, we are either using a tool calling LLM or we are skipping the tool call.
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=built_prompt,
        tools=(
            [tool.tool_definition() for tool in tools] or None
            if using_tool_calling_llm
            else None
        ),
        tool_choice=(
            "required"
            if tools and force_use_tool.force_use and using_tool_calling_llm
            else None
        ),
        structured_response_format=structured_response_format,
    )

    tool_message = process_llm_stream(
        stream,
        should_stream_answer
        and not agent_config.behavior.skip_gen_ai_answer_generation,
        writer,
        ind=0,
    ).ai_message_chunk

    if tool_message is None:
        raise ValueError("No tool message emitted by LLM")

    # If no tool calls are emitted by the LLM, we should not choose a tool
    if len(tool_message.tool_calls) == 0:
        logger.debug("No tool calls emitted by LLM")
        return ToolChoiceUpdate(
            tool_choice=None,
        )

    # TODO: here we could handle parallel tool calls. Right now
    # we just pick the first one that matches.
    selected_tool: Tool | None = None
    selected_tool_call_request: ToolCall | None = None
    for tool_call_request in tool_message.tool_calls:
        known_tools_by_name = [
            tool for tool in tools if tool.name == tool_call_request["name"]
        ]

        if known_tools_by_name:
            selected_tool = known_tools_by_name[0]
            selected_tool_call_request = tool_call_request
            break

        logger.error(
            "Tool call requested with unknown name field. \n"
            f"tools: {tools}"
            f"tool_call_request: {tool_call_request}"
        )

    if not selected_tool or not selected_tool_call_request:
        raise ValueError(
            f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
        )

    logger.debug(f"Selected tool: {selected_tool.name}")
    logger.debug(f"Selected tool call request: {selected_tool_call_request}")

    if embedding_thread and selected_tool.name == SearchTool._NAME:
        # Wait for the embedding thread to finish
        embedding = wait_on_background(embedding_thread)
        override_kwargs.precomputed_query_embedding = embedding
    if keyword_thread and selected_tool.name == SearchTool._NAME:
        is_keyword, keywords = wait_on_background(keyword_thread)
        override_kwargs.precomputed_is_keyword = is_keyword
        override_kwargs.precomputed_keywords = keywords
    if (
        selected_tool.name == SearchTool._NAME
        and expanded_keyword_thread
        and expanded_semantic_thread
    ):
        override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
            expanded_keyword_thread=expanded_keyword_thread,
            expanded_semantic_thread=expanded_semantic_thread,
        )

    if (
        USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
        and selected_tool.name == SearchTool._NAME
        and override_kwargs.expanded_queries
    ):
        # TODO: this is a hack to handle the case where the expanded queries are not found.
        # We should refactor this to be more robust.
        if (
            override_kwargs.expanded_queries.keywords_expansions is None
            or override_kwargs.expanded_queries.semantic_expansions is None
        ):
            raise ValueError("No expanded keyword or semantic threads found.")

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,
            tool_args=selected_tool_call_request["args"],
            id=selected_tool_call_request["id"],
            search_tool_override_kwargs=override_kwargs,
        ),
    )

View File

@@ -1,17 +0,0 @@
from typing import Any
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    """Build the ToolChoiceInput for the top-level agent graph node.

    Pulls the GraphConfig stashed under ``config["metadata"]["config"]`` and
    exposes only the tool *names*; the tool objects themselves remain on the
    agent config.
    """
    graph_config = cast(GraphConfig, config["metadata"]["config"])
    available_tools = graph_config.tooling.tools or []
    tool_names = [available_tool.name for available_tool in available_tools]
    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        # None -> fall back to the default prompt builder
        prompt_snapshot=None,
        tools=tool_names,
    )

View File

@@ -5,11 +5,10 @@ from typing import Literal
from typing import Type
from typing import TypeVar
from braintrust import traced
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from litellm import get_supported_openai_params
from litellm import supports_response_schema
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
@@ -29,6 +28,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
@traced(name="stream llm", type="llm")
def stream_llm_answer(
llm: LLM,
prompt: LanguageModelInput,
@@ -147,6 +147,7 @@ def invoke_llm_json(
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
and return an object of that schema.
"""
from litellm.utils import get_supported_openai_params, supports_response_schema
# check if the model supports response_format: json_schema
supports_json = "response_format" in (

View File

@@ -29,7 +29,7 @@ from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import ONYX_DISCORD_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
@@ -145,7 +145,7 @@ HTML_EMAIL_TEMPLATE = """\
<tr>
<td class="footer">
© {year} {application_name}. All rights reserved.
{slack_fragment}
{community_link_fragment}
</td>
</tr>
</table>
@@ -161,9 +161,9 @@ def build_html_email(
cta_text: str | None = None,
cta_link: str | None = None,
) -> str:
slack_fragment = ""
community_link_fragment = ""
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
community_link_fragment = f'<br>Have questions? Join our Discord community <a href="{ONYX_DISCORD_URL}">here</a>.'
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
@@ -175,7 +175,7 @@ def build_html_email(
heading=heading,
message=message,
cta_block=cta_block,
slack_fragment=slack_fragment,
community_link_fragment=community_link_fragment,
year=datetime.now().year,
)

View File

@@ -1040,7 +1040,10 @@ async def optional_user(
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
try:
hashed_api_key = get_hashed_api_key_from_request(request)
except ValueError:
hashed_api_key = None
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

View File

@@ -0,0 +1,111 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
# Module-level logger for this worker app.
logger = setup_logger()
# Consolidated "background" Celery worker app; its task modules are
# registered via autodiscover_tasks further down in this module.
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
# Use the tenant-aware task base class so tasks run with the correct tenant
# context (see app_base.TenantAwareTask).
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Forward Celery's task_prerun signal to the shared app_base handler.

    NOTE: parameter names must match the keyword arguments Celery sends with
    the signal, so they cannot be renamed.
    """
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Forward Celery's task_postrun signal to the shared app_base handler.

    NOTE: parameter names must match the keyword arguments Celery sends with
    the signal, so they cannot be renamed.
    """
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    """Forward the celeryd_init signal to the shared app_base handler."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    """Initialize the consolidated background worker.

    Sets up the SQL engine (pool sized to the worker's concurrency) and then
    waits on Redis, the DB, and Vespa before the worker starts consuming
    tasks. In multi-tenant mode the secondary-worker init step is skipped.
    """
    logger.info("worker_init signal received for consolidated background worker.")
    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
    # Size the connection pool to the worker's concurrency so each concurrent
    # task can hold a connection.
    pool_size = cast(int, sender.concurrency) # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
    # Block startup until required services are reachable (Vespa failure
    # shuts the worker down, per the helper's name).
    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
    # Less startup checks in multi-tenant case
    if MULTI_TENANT:
        return
    app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Forward the worker_ready signal to the shared app_base handler."""
    app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Forward the worker_shutdown signal to the shared app_base handler."""
    app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
    """Reset the SQL engine in each newly started worker process.

    Ensures pooled connections created in the parent process are not reused
    across process boundaries.
    """
    SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Forward Celery's setup_logging signal to the shared app_base handler.

    Connecting this signal also stops Celery from configuring logging itself
    (per Celery's documented setup_logging behavior).
    """
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# Install the shared worker bootsteps provided by app_base.
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)
# Register the task modules this consolidated background worker serves.
celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.pruning",
        "onyx.background.celery.tasks.kg_processing",
        "onyx.background.celery.tasks.monitoring",
        "onyx.background.celery.tasks.user_file_processing",
    ]
)

View File

@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
# Ensure the user files indexing worker registers the doc_id migration task
# TODO(subash): remove this once the doc_id migration is complete
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -324,5 +324,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -19,7 +19,9 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -30,7 +32,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: Iterator[list[Document]],
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
@@ -41,20 +43,24 @@ def extract_ids_from_runnable_connector(
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
if isinstance(runnable_connector, SlimConnector):
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
if isinstance(runnable_connector, SlimConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs()
)
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs_perm_sync()
)
# If the connector isn't slim, fall back to running it normally to get ids
elif isinstance(runnable_connector, LoadConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
@@ -78,13 +84,14 @@ def extract_ids_from_runnable_connector(
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
doc_batch_processing_func = (
rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(lambda x: x)
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
else lambda x: x
)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():

View File

@@ -0,0 +1,21 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
# Broker settings shared by all Onyx Celery workers.
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
# Redis keepalive / health-check settings (shared).
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
# Result backend settings (shared).
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
# Task defaults (shared).
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Background-worker-specific settings: concurrency comes from app config,
# worker threads (not processes), and at most one prefetched task each.
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = 4
worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -16,6 +17,6 @@ task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Monitoring worker specific settings
worker_concurrency = 1 # Single worker is sufficient for monitoring
worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -33,17 +33,25 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
},
{
"name": "check-for-user-file-project-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "user-file-docid-migration",
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
"schedule": timedelta(minutes=1),
"schedule": timedelta(minutes=10),
"options": {
"priority": OnyxCeleryPriority.LOW,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
},
},
{
@@ -85,9 +93,9 @@ beat_task_templates: list[dict] = [
{
"name": "check-for-index-attempt-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
"schedule": timedelta(hours=1),
"schedule": timedelta(minutes=30),
"options": {
"priority": OnyxCeleryPriority.LOW,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},

View File

@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
tenant_id: str,
batch_num: int,
) -> None:
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic()
if tenant_id:

View File

@@ -9,17 +9,19 @@ import sqlalchemy as sa
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
@@ -36,13 +38,12 @@ from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import S3BackedFileStore
from onyx.httpx.httpx_pool import HttpxPool
@@ -134,7 +135,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
@@ -306,8 +308,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.needs_project_sync.is_(True)
and UserFile.status == UserFileStatus.COMPLETED
sa.and_(
UserFile.needs_project_sync.is_(True),
UserFile.status == UserFileStatus.COMPLETED,
)
)
)
.scalars()
@@ -348,7 +352,7 @@ def process_single_user_file_project_sync(
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
@@ -359,6 +363,15 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
@@ -416,6 +429,7 @@ def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
return old_id
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
http_client: httpx.Client,
@@ -423,10 +437,13 @@ def _visit_chunks(
selection: str,
continuation: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
task_logger.info(
f"Visiting chunks for index={index_name} with selection={selection}"
)
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
"wantedDocumentCount": "100", # Use smaller batch size to avoid timeouts
}
if continuation:
params["continuation"] = continuation
@@ -436,127 +453,155 @@ def _visit_chunks(
return payload.get("documents", []), payload.get("continuation")
def _update_document_id_in_vespa(
*,
index_name: str,
old_doc_id: str,
new_doc_id: str,
user_project_ids: list[int] | None = None,
) -> None:
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
def update_legacy_plaintext_file_records() -> None:
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
keys. Copies each S3 object to its expected UUID key and updates DB.
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
task_logger.debug(f"Vespa selection: {selection}")
Examples:
- Old key: bucket/schema/plaintext_<int>
- New key: bucket/schema/plaintext_<uuid>
"""
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
task_logger.info("update_legacy_plaintext_file_records - Starting")
with get_session_with_current_tenant() as db_session:
store = get_default_file_store()
if not isinstance(store, S3BackedFileStore):
task_logger.info(
"update_legacy_plaintext_file_records - Skipping non-S3 store"
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
return
s3_client = store._get_s3_client()
bucket_name = store._get_bucket_name()
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
plaintext_records: Sequence[FileRecord] = (
db_session.execute(
sa.select(FileRecord).where(
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
)
)
.scalars()
.all()
)
task_logger.info(
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
)
normalized = 0
for fr in plaintext_records:
try:
expected_key = store._get_s3_key(fr.file_id)
if fr.object_key == expected_key:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request: dict[str, Any] = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
if user_project_ids is not None:
update_request["fields"][USER_PROJECT] = {
"assign": user_project_ids
}
r = http_client.put(vespa_url, json=update_request)
r.raise_for_status()
if not continuation:
break
if fr.bucket_name is None:
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
continue
if fr.object_key is None:
task_logger.warning(f"id={fr.file_id} - Object key is None")
continue
# Copy old object to new key
copy_source = f"{fr.bucket_name}/{fr.object_key}"
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=expected_key,
MetadataDirective="COPY",
)
# Delete old object (best-effort)
try:
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
except Exception:
pass
# Update DB record with new key
fr.object_key = expected_key
db_session.add(fr)
normalized += 1
except Exception as e:
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
if normalized:
db_session.commit()
task_logger.info(
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
)
@shared_task(
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
ignore_result=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
bind=True,
)
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
- Update `search_doc.document_id` to the same UUID string.
"""
task_logger.info(
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
)
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
)
if not lock.acquire(blocking=False):
task_logger.info(
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
)
return False
updated_count = 0
try:
update_legacy_plaintext_file_records()
# Track lock renewal
last_lock_time = time.monotonic()
with get_session_with_current_tenant() as db_session:
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
active_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
active_settings.primary,
active_settings.secondary,
search_settings=active_settings.primary,
secondary_search_settings=active_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
index_name = "danswer_index"
# Fetch mappings of legacy -> new ids
rows = db_session.execute(
sa.select(
UserFile.document_id.label("document_id"),
UserFile.id.label("id"),
).where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
).all()
retry_index = RetryDocumentIndex(document_index)
# dedupe by old document_id
seen: set[str] = set()
for row in rows:
old_doc_id = str(row.document_id)
new_uuid = str(row.id)
if not old_doc_id or not new_uuid or old_doc_id in seen:
continue
seen.add(old_doc_id)
# collect user project ids for a combined Vespa update
user_project_ids: list[int] | None = None
try:
uf = db_session.get(UserFile, UUID(new_uuid))
if uf is not None:
user_project_ids = [project.id for project in uf.projects]
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
)
try:
_update_document_id_in_vespa(
index_name=index_name,
old_doc_id=old_doc_id,
new_doc_id=new_uuid,
user_project_ids=user_project_ids,
)
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
)
# Update search_doc records to refer to the UUID string
# we are not using document_id_migrated = false because if the migration already completed,
# it will not run again and we will not update the search_doc records because of the issue currently fixed
# Select user files with a legacy doc id that have not been migrated
user_files = (
db_session.execute(
sa.select(UserFile).where(UserFile.document_id.is_not(None))
sa.select(UserFile).where(
sa.and_(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
)
)
.scalars()
.all()
)
task_logger.info(
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
)
# Query all SearchDocs that need updating
search_docs = (
db_session.execute(
@@ -567,9 +612,9 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
.scalars()
.all()
)
task_logger.info(f"Found {len(user_files)} user files to update")
task_logger.info(f"Found {len(search_docs)} search docs to update")
task_logger.info(
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
)
# Build a map of normalized doc IDs to SearchDocs
search_doc_map: dict[str, list[SearchDoc]] = {}
@@ -579,102 +624,139 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
search_doc_map[doc_id] = []
search_doc_map[doc_id].append(sd)
# Process each UserFile and update matching SearchDocs
updated_count = 0
for uf in user_files:
doc_id = uf.document_id
if doc_id.startswith("USER_FILE_CONNECTOR__"):
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
task_logger.debug(
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
)
if doc_id in search_doc_map:
# Update the SearchDoc to use the UserFile's UUID
for search_doc in search_doc_map[doc_id]:
search_doc.document_id = str(uf.id)
db_session.add(search_doc)
ids_preview = list(search_doc_map.keys())[:5]
task_logger.debug(
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
)
task_logger.debug(
f"user_file_docid_migration_task - search_doc_map total items: "
f"{sum(len(docs) for docs in search_doc_map.values())}"
)
for user_file in user_files:
# Periodically renew the Redis lock to prevent expiry mid-run
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
):
renewed = False
try:
# extend lock ttl to full timeout window
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
renewed = True
except Exception:
# if extend fails, best-effort reacquire as a fallback
try:
lock.reacquire()
renewed = True
except Exception:
renewed = False
last_lock_time = current_time
if not renewed or not lock.owned():
task_logger.error(
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
)
return False
# Mark UserFile as migrated
uf.document_id_migrated = True
db_session.add(uf)
try:
clean_old_doc_id = replace_invalid_doc_id_characters(
user_file.document_id
)
normalized_doc_id = _normalize_legacy_user_file_doc_id(
clean_old_doc_id
)
user_project_ids = [project.id for project in user_file.projects]
task_logger.info(
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
)
index_name = active_settings.primary.index_name
# First find the chunks count using direct Vespa query
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
# Count all chunks for this document
chunk_count = 0
continuation = None
while True:
docs, continuation = _visit_chunks(
http_client=HttpxPool.get("vespa"),
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
chunk_count += len(docs)
if not continuation:
break
task_logger.info(
f"Found {chunk_count} chunks for document {normalized_doc_id}"
)
# Now update Vespa chunks with the found chunk count using retry_index
updated_chunks = retry_index.update_single(
doc_id=str(normalized_doc_id),
tenant_id=tenant_id,
chunk_count=chunk_count,
fields=VespaDocumentFields(document_id=str(user_file.id)),
user_fields=VespaDocumentUserFields(
user_projects=user_project_ids
),
)
user_file.chunk_count = updated_chunks
# Update the SearchDocs
actual_doc_id = str(user_file.document_id)
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
actual_doc_id
)
if (
normalized_doc_id in search_doc_map
or normalized_actual_doc_id in search_doc_map
):
to_update = (
search_doc_map[normalized_doc_id]
if normalized_doc_id in search_doc_map
else search_doc_map[normalized_actual_doc_id]
)
task_logger.debug(
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
)
for search_doc in to_update:
search_doc.document_id = str(user_file.id)
db_session.add(search_doc)
user_file.document_id_migrated = True
db_session.add(user_file)
db_session.commit()
updated_count += 1
except Exception as per_file_exc:
# Rollback the current transaction and continue with the next file
db_session.rollback()
task_logger.exception(
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
f"{per_file_exc.__class__.__name__}"
)
task_logger.info(
f"Updated {updated_count} SearchDoc records with new UUIDs"
f"user_file_docid_migration_task - Updated {updated_count} user files"
)
db_session.commit()
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
try:
store = get_default_file_store()
# Only supported for S3-backed stores where we can manipulate object keys
if isinstance(store, S3BackedFileStore):
s3_client = store._get_s3_client()
bucket_name = store._get_bucket_name()
plaintext_records: Sequence[FileRecord] = (
db_session.execute(
sa.select(FileRecord).where(
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
FileRecord.file_id.like("plaintext_%"),
)
)
.scalars()
.all()
)
normalized = 0
for fr in plaintext_records:
try:
expected_key = store._get_s3_key(fr.file_id)
if fr.object_key == expected_key:
continue
# Copy old object to new key
copy_source = f"{fr.bucket_name}/{fr.object_key}"
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=expected_key,
MetadataDirective="COPY",
)
# Delete old object (best-effort)
try:
s3_client.delete_object(
Bucket=fr.bucket_name, Key=fr.object_key
)
except Exception:
pass
# Update DB record with new key
fr.object_key = expected_key
db_session.add(fr)
normalized += 1
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed plaintext object normalize for "
f"id={fr.file_id} - {e.__class__.__name__}"
)
if normalized:
db_session.commit()
task_logger.info(
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
)
else:
task_logger.info(
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
)
except Exception:
task_logger.exception(
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
)
task_logger.info(
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
)
return True
except Exception:
except Exception as e:
task_logger.exception(
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
f"(updated={updated_count}) exception={e.__class__.__name__}"
)
return False
finally:
if lock.owned():
lock.release()

View File

@@ -0,0 +1,10 @@
# Celery entrypoint module: resolves the correct Celery app implementation
# (EE vs. community edition) at import time.
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

# Must run before fetch_versioned_implementation so the EE flag is decided
# from the environment first.
set_is_ee_based_on_env_variable()

# The resolved Celery application; which module provides it depends on the
# EE flag set above.
app: Celery = fetch_versioned_implementation(
    "onyx.background.celery.apps.background",
    "celery_app",
)

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
def get_old_index_attempts(
@@ -21,6 +22,10 @@ def get_old_index_attempts(
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
    """Clean up multiple index attempts.

    Deletes the IndexAttemptError rows for the given attempts first, then the
    IndexAttempt rows themselves (presumably to satisfy the FK from
    IndexAttemptError to IndexAttempt — confirm against the schema).
    Does not commit; the caller owns the transaction.
    """
    db_session.query(IndexAttemptError).filter(
        IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
    ).delete(synchronize_session=False)
    db_session.query(IndexAttempt).filter(
        IndexAttempt.id.in_(index_attempt_ids)
    ).delete(synchronize_session=False)

View File

@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
@@ -101,7 +102,6 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
from onyx.connectors.factory import instantiate_connector
task = attempt.connector_credential_pair.connector.input_type

View File

@@ -2,14 +2,15 @@ import re
import time
import traceback
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast
from typing import Protocol
from uuid import UUID
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -26,12 +27,12 @@ from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.packet_proccessing.process_streamed_packets import (
process_streamed_packets,
)
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.turn import fast_chat_turn
from onyx.chat.turn.infra.emitter import get_default_emitter
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.chat.user_files.parse_user_files import parse_user_files
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -89,6 +90,7 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
@@ -113,6 +115,10 @@ logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
class ToolCallException(Exception):
"""Exception raised for errors during tool calls."""
class PartialResponse(Protocol):
def __call__(
self,
@@ -352,14 +358,12 @@ def stream_chat_message_objects(
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
persona = _get_persona_for_chat_session(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
default_persona=chat_session.persona,
)
# TODO: remove once we have an endpoint for this stuff
process_kg_commands(new_msg_req.message, persona.name, tenant_id, db_session)
@@ -731,12 +735,19 @@ def stream_chat_message_objects(
)
prompt_builder = AnswerPromptBuilder(
# TODO: for backwards compatibility, we are using the V1
# user_message=default_build_user_message_v2(
# user_query=final_msg.message,
# prompt_config=prompt_config,
# files=latest_query_files,
# ),
user_message=default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
# TODO: for backwards compatibility, we are using the V1
# system_message=default_build_system_message_v2(prompt_config, llm.config),
system_message=default_build_system_message(prompt_config, llm.config),
message_history=message_history,
llm_config=llm.config,
@@ -780,10 +791,20 @@ def stream_chat_message_objects(
project_instructions=project_instructions,
)
# Process streamed packets using the new packet processing module
yield from process_streamed_packets(
from onyx.chat.packet_proccessing import process_streamed_packets
yield from process_streamed_packets.process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
)
# TODO: For backwards compatible PR, switch back to the original call
# yield from _fast_message_stream(
# answer,
# tools,
# db_session,
# get_redis_client(),
# str(chat_session_id),
# str(reserved_message_id),
# )
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -820,6 +841,59 @@ def stream_chat_message_objects(
return
# TODO: Refactor this to live somewhere else
def _fast_message_stream(
    answer: Answer,
    tools: list[Tool],
    db_session: Session,
    redis_client: Redis,
    chat_session_id: UUID,
    reserved_message_id: int,
) -> Generator[Packet, None, None]:
    """Run the chat turn through the fast_chat_turn path, streaming Packets.

    Not itself a generator: it returns the generator produced by the
    @unified_event_stream-decorated fast_chat_turn.
    """
    # Local imports, presumably to avoid import cycles — TODO confirm.
    from onyx.tools.tool_implementations.images.image_generation_tool import (
        ImageGenerationTool,
    )
    from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
        OktaProfileTool,
    )
    from onyx.llm.litellm_singleton import LitellmModel

    # Pick out the concrete tool instances that ChatTurnDependencies wants
    # direct references to, if they are among the configured tools.
    image_generation_tool_instance = None
    okta_profile_tool_instance = None
    for tool in tools:
        if isinstance(tool, ImageGenerationTool):
            image_generation_tool_instance = tool
        elif isinstance(tool, OktaProfileTool):
            okta_profile_tool_instance = tool

    # Convert the built prompt messages into agent-SDK message dicts.
    converted_message_history = [
        PreviousMessage.from_langchain_msg(message, 0).to_agent_sdk_msg()
        for message in answer.graph_inputs.prompt_builder.build()
    ]
    emitter = get_default_emitter()
    return fast_chat_turn.fast_chat_turn(
        messages=converted_message_history,
        # TODO: Maybe we can use some DI framework here?
        dependencies=ChatTurnDependencies(
            llm_model=LitellmModel(
                model=answer.graph_tooling.primary_llm.config.model_name,
                base_url=answer.graph_tooling.primary_llm.config.api_base,
                api_key=answer.graph_tooling.primary_llm.config.api_key,
            ),
            llm=answer.graph_tooling.primary_llm,
            tools=tools_to_function_tools(tools),
            search_pipeline=answer.graph_tooling.search_tool,
            image_generation_tool=image_generation_tool_instance,
            okta_profile_tool=okta_profile_tool_instance,
            db_session=db_session,
            redis_client=redis_client,
            emitter=emitter,
        ),
        chat_session_id=chat_session_id,
        message_id=reserved_message_id,
        research_type=answer.graph_config.behavior.research_type,
    )
@log_generator_function_time()
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,

View File

@@ -21,8 +21,10 @@ from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
@@ -31,6 +33,33 @@ from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
def default_build_system_message_v2(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
) -> SystemMessage | None:
    """Build the v2 system message for a chat turn, or None if it ends up empty."""
    base_prompt = prompt_config.system_prompt.strip() + REQUIRE_CITATION_STATEMENT

    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix
    # for o-series markdown generation
    is_openai_o_series = (
        llm_config.model_provider == OPENAI_PROVIDER_NAME
        and llm_config.model_name.startswith("o")
    )
    if is_openai_o_series:
        base_prompt = CODE_BLOCK_MARKDOWN + base_prompt

    dated_prompt = handle_onyx_date_awareness(
        base_prompt,
        prompt_config,
        add_additional_info_if_no_tag=prompt_config.datetime_aware,
    )
    if not dated_prompt:
        return None

    return SystemMessage(content=handle_company_awareness(dated_prompt))
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
@@ -52,9 +81,29 @@ def default_build_system_message(
if not tag_handled_prompt:
return None
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
return SystemMessage(content=tag_handled_prompt)
def default_build_user_message_v2(
    user_query: str,
    prompt_config: PromptConfig,
    files: list[InMemoryChatFile] | None = None,
) -> HumanMessage:
    """Build the v2 user message for a chat turn.

    Args:
        user_query: The raw user query text.
        prompt_config: Prompt configuration used for date-awareness handling.
        files: Optional in-memory files to inline (e.g. images); when falsy the
            message content is plain text.

    Returns:
        The HumanMessage to send to the LLM.
    """
    # NOTE: the default was previously a mutable `[]`; `None` avoids the
    # shared-mutable-default pitfall while preserving behavior (both are falsy).
    tag_handled_prompt = handle_onyx_date_awareness(user_query.strip(), prompt_config)
    return HumanMessage(
        content=(
            build_content_with_imgs(tag_handled_prompt, files)
            if files
            else tag_handled_prompt
        )
    )
def default_build_user_message(
user_query: str,
prompt_config: PromptConfig,

View File

@@ -0,0 +1,56 @@
from uuid import UUID
from redis.client import Redis
from shared_configs.contextvars import get_current_tenant_id
# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
    """
    Set or clear the stop signal fence for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
        value: True to set the fence (stop signal), False to clear it
    """
    fence_key = f"{FENCE_PREFIX}_{get_current_tenant_id()}_{chat_session_id}"
    if value:
        redis_client.set(fence_key, 0)
    else:
        redis_client.delete(fence_key)
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
    """
    Check if the chat session should continue (not stopped).

    Args:
        chat_session_id: The UUID of the chat session to check
        redis_client: Redis client to use for checking the stop signal

    Returns:
        True if the session should continue, False if it should stop
    """
    tenant = get_current_tenant_id()
    stop_key = f"{FENCE_PREFIX}_{tenant}_{chat_session_id}"
    fence_is_set = bool(redis_client.exists(stop_key))
    return not fence_is_set
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
    """
    Clear the stop signal for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
    """
    redis_client.delete(f"{FENCE_PREFIX}_{get_current_tenant_id()}_{chat_session_id}")

View File

@@ -0,0 +1 @@
# Turn module for chat functionality

View File

@@ -0,0 +1,257 @@
from typing import cast
from uuid import UUID
from agents import Agent
from agents import ModelSettings
from agents import RawResponsesStreamEvent
from agents import StopAtTools
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.turn.infra.chat_turn_event_stream import unified_event_stream
from onyx.chat.turn.infra.session_sink import extract_final_answer_from_packets
from onyx.chat.turn.infra.session_sink import save_iteration
from onyx.chat.turn.infra.sync_agent_stream_adapter import SyncAgentStream
from onyx.chat.turn.models import ChatTurnContext
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketObj
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
def _fast_chat_turn_core(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    # Dependency injectable arguments for testing
    starter_global_iteration_responses: list[IterationAnswer] | None = None,
    starter_cited_documents: list[InferenceSection] | None = None,
) -> None:
    """Core fast chat turn logic that allows overriding global_iteration_responses for testing.

    Args:
        messages: List of chat messages
        dependencies: Chat turn dependencies
        chat_session_id: Chat session ID
        message_id: Message ID
        research_type: Research type
        starter_global_iteration_responses: Optional list of iteration answers to inject for testing
        starter_cited_documents: Optional list of cited documents to inject for testing
    """
    # Clear any stale stop signal left over from a previous turn.
    reset_cancel_status(
        chat_session_id,
        dependencies.redis_client,
    )
    ctx = ChatTurnContext(
        run_dependencies=dependencies,
        aggregated_context=AggregatedDRContext(
            context="context",
            cited_documents=starter_cited_documents or [],
            is_internet_marker_dict={},
            global_iteration_responses=starter_global_iteration_responses or [],
        ),
        iteration_instructions=[],
        chat_session_id=chat_session_id,
        message_id=message_id,
        research_type=research_type,
    )
    agent = Agent(
        name="Assistant",
        model=dependencies.llm_model,
        tools=cast(list, dependencies.tools),  # type: ignore[arg-type]
        model_settings=ModelSettings(
            temperature=dependencies.llm.config.temperature,
            include_usage=True,
        ),
        # Stop the agent run as soon as the image generation tool is invoked.
        tool_use_behavior=StopAtTools(stop_at_tool_names=[image_generation_tool.name]),
    )
    # By default, the agent can only take 10 turns. For our use case, it should be higher.
    max_turns = 25
    agent_stream: SyncAgentStream = SyncAgentStream(
        agent=agent,
        input=messages,
        context=ctx,
        max_turns=max_turns,
    )
    for ev in agent_stream:
        # Poll the Redis stop-signal fence between events so a user-initiated
        # cancel takes effect mid-stream.
        connected = is_connected(
            chat_session_id,
            dependencies.redis_client,
        )
        if not connected:
            _emit_clean_up_packets(dependencies, ctx)
            agent_stream.cancel()
            break
        obj = _default_packet_translation(ev, ctx)
        if obj:
            dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))

    # Reconstruct the full answer text from everything emitted so far.
    final_answer = extract_final_answer_from_packets(
        dependencies.emitter.packet_history
    )

    all_cited_documents = []
    if ctx.aggregated_context.global_iteration_responses:
        context_docs = _gather_context_docs_from_iteration_answers(
            ctx.aggregated_context.global_iteration_responses
        )
        all_cited_documents = context_docs
        if context_docs and final_answer:
            _process_citations_for_final_answer(
                final_answer=final_answer,
                context_docs=context_docs,
                dependencies=dependencies,
                ctx=ctx,
            )

    # Persist the turn (docs, citations, message, iteration records).
    save_iteration(
        db_session=dependencies.db_session,
        message_id=message_id,
        chat_session_id=chat_session_id,
        research_type=research_type,
        ctx=ctx,
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
    )
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
    )
@unified_event_stream
def fast_chat_turn(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters.

    Decorated with @unified_event_stream, so callers receive a generator of
    Packets rather than None.
    """
    _fast_chat_turn_core(
        messages,
        dependencies,
        chat_session_id,
        message_id,
        research_type,
        starter_global_iteration_responses=None,
    )
# TODO: Maybe in general there's a cleaner way to handle cancellation in the middle of a tool call?
def _emit_clean_up_packets(
    dependencies: ChatTurnDependencies, ctx: ChatTurnContext
) -> None:
    """Emit the packets needed to close out the stream after a cancellation."""
    history = dependencies.emitter.packet_history
    last_was_message_delta = bool(history) and history[-1].obj.type == "message_delta"

    # If no message was mid-stream, surface an explicit "Cancelled" message
    # so the client has something to render before the section closes.
    if not last_was_message_delta:
        cancelled_start = MessageStart(
            type="message_start", content="Cancelled", final_documents=None
        )
        dependencies.emitter.emit(
            Packet(ind=ctx.current_run_step, obj=cancelled_start)
        )

    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=SectionEnd(type="section_end"))
    )
def _gather_context_docs_from_iteration_answers(
    iteration_answers: list[IterationAnswer],
) -> list[InferenceSection]:
    """Gather cited documents from iteration answers for citation processing.

    De-duplicates by center_chunk.document_id while preserving first-seen
    order. Uses a seen-set instead of the previous linear `any()` scan per
    document, which was accidentally O(n^2).
    """
    seen_doc_ids = set()
    context_docs: list[InferenceSection] = []
    for iteration_answer in iteration_answers:
        # Extract cited documents from this iteration
        for inference_section in iteration_answer.cited_documents.values():
            doc_id = inference_section.center_chunk.document_id
            if doc_id not in seen_doc_ids:
                seen_doc_ids.add(doc_id)
                context_docs.append(inference_section)
    return context_docs
def _process_citations_for_final_answer(
    final_answer: str,
    context_docs: list[InferenceSection],
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
) -> None:
    """Process citations in the final answer and emit citation events.

    Fix: this docstring previously appeared after the first statement, where
    it was a no-op string expression rather than a docstring.
    """
    from onyx.chat.stream_processing.utils import DocumentIdOrderMapping

    # Citations are emitted as a new run step after the current one.
    index = ctx.current_run_step + 1

    # Convert InferenceSection objects to LlmDoc objects for citation processing
    llm_docs = [llm_doc_from_inference_section(section) for section in context_docs]

    # Create document ID to rank mappings (simple 1-based indexing)
    final_doc_id_to_rank_map = DocumentIdOrderMapping(
        order_mapping={doc.document_id: i + 1 for i, doc in enumerate(llm_docs)}
    )
    display_doc_id_to_rank_map = final_doc_id_to_rank_map  # Same mapping for display

    # Initialize citation processor
    citation_processor = CitationProcessor(
        context_docs=llm_docs,
        final_doc_id_to_rank_map=final_doc_id_to_rank_map,
        display_doc_id_to_rank_map=display_doc_id_to_rank_map,
    )

    # Process the final answer through citation processor
    collected_citations: list = []
    for response_part in citation_processor.process_token(final_answer):
        if hasattr(response_part, "citation_num"):  # It's a CitationInfo
            collected_citations.append(response_part)

    # Emit citation events if we found any citations
    if collected_citations:
        dependencies.emitter.emit(Packet(ind=index, obj=CitationStart()))
        dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CitationDelta(citations=collected_citations),  # type: ignore[arg-type]
            )
        )
        dependencies.emitter.emit(
            Packet(ind=index, obj=SectionEnd(type="section_end"))
        )

    ctx.current_run_step = index
def _default_packet_translation(ev: object, ctx: ChatTurnContext) -> PacketObj | None:
    """Translate a raw agent-stream event into a streaming PacketObj, if any."""
    if not isinstance(ev, RawResponsesStreamEvent):
        return None

    # TODO: might need some variation here for different types of models
    # OpenAI packet translator
    event_type = ev.data.type
    if event_type == "response.content_part.added":
        final_documents = convert_inference_sections_to_search_docs(
            ctx.aggregated_context.cited_documents
        )
        return MessageStart(
            type="message_start", content="", final_documents=final_documents
        )
    if event_type == "response.output_text.delta":
        return MessageDelta(type="message_delta", content=ev.data.delta)
    if event_type == "response.content_part.done":
        return SectionEnd(type="section_end")
    return None

View File

@@ -0,0 +1 @@
# Infrastructure module for chat turn orchestration

View File

@@ -0,0 +1,57 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Dict
from typing import List
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background
def unified_event_stream(
    turn_func: Callable[..., None],
) -> Callable[..., Generator[Packet, None]]:
    """
    Decorator that wraps a turn_func to provide event streaming capabilities.

    The wrapped function runs in a background thread and communicates via
    dependencies.emitter; the wrapper yields every emitted Packet until an
    OverallStop arrives, and re-raises any exception the turn raised.

    Usage:
        @unified_event_stream
        def my_turn_func(messages, dependencies, *args, **kwargs):
            # Your turn logic here
            pass

    Then call it like:
        generator = my_turn_func(messages, dependencies, *args, **kwargs)
    """
    from functools import wraps  # local import keeps this block self-contained

    @wraps(turn_func)  # fix: preserve the wrapped function's name/docstring
    def wrapper(
        messages: List[Dict[str, Any]],
        dependencies: ChatTurnDependencies,
        *args: Any,
        **kwargs: Any,
    ) -> Generator[Packet, None]:
        def run_with_exception_capture() -> None:
            # Surface worker-thread exceptions as a PacketException so the
            # consuming side can re-raise them.
            try:
                turn_func(messages, dependencies, *args, **kwargs)
            except Exception as e:
                dependencies.emitter.emit(
                    Packet(ind=0, obj=PacketException(type="error", exception=e))
                )

        thread = run_in_background(run_with_exception_capture)
        while True:
            pkt: Packet = dependencies.emitter.bus.get()
            if pkt.obj == OverallStop(type="stop"):
                yield pkt
                break
            elif isinstance(pkt.obj, PacketException):
                # Fix: join the worker (it has already terminated after
                # emitting the exception packet) before propagating the
                # error; previously this path never waited on the thread.
                wait_on_background(thread)
                raise pkt.obj.exception
            else:
                yield pkt
        wait_on_background(thread)

    return wrapper

View File

@@ -0,0 +1,21 @@
from queue import Queue
from onyx.server.query_and_chat.streaming_models import Packet
class Emitter:
    """Use this inside tools to emit arbitrary UI progress."""

    def __init__(self, bus: Queue):
        # Thread-safe channel the streaming consumer reads packets from.
        self.bus = bus
        # Every packet ever emitted, in emission order.
        self.packet_history: list[Packet] = []

    def emit(self, packet: Packet) -> None:
        """Publish a packet to the bus and record it in packet_history."""
        self.bus.put(packet)
        self.packet_history.append(packet)
def get_default_emitter() -> Emitter:
    """Create an Emitter backed by a fresh unbounded queue."""
    packet_bus: Queue[Packet] = Queue()
    return Emitter(packet_bus)

View File

@@ -0,0 +1,168 @@
# TODO: Figure out a way to persist information is robust to cancellation,
# modular so easily testable in unit tests and evals [likely injecting some higher
# level session manager and span sink], potentially has some robustness off the critical path,
# and promotes clean separation of concerns.
import re
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.turn.models import ChatTurnContext
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
def save_iteration(
    db_session: Session,
    message_id: int,
    chat_session_id: UUID,
    research_type: ResearchType,
    ctx: ChatTurnContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
) -> None:
    """Persist the results of a completed chat turn.

    Writes, in order: the SearchDocs for every cited document, the
    message<->search-doc link rows, the citation mapping plus the final answer
    on the chat message, and one ResearchAgentIteration(/SubStep) row per
    recorded iteration. Commits once at the end.
    """
    # first, insert the search_docs
    is_internet_marker_dict: dict[str, bool] = {}
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        for cited_doc_nr in cited_doc_nrs:
            # Fix: guard against citation numbers outside the range of
            # available docs. Previously an out-of-range number raised
            # IndexError, and 0 (or a negative) silently mapped to the wrong
            # doc via Python's negative indexing.
            if 1 <= cited_doc_nr <= len(search_docs):
                citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id

    llm_tokenizer = get_tokenizer(
        model_name=ctx.run_dependencies.llm.config.model_name,
        provider_type=ctx.run_dependencies.llm.config.model_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))

    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=chat_session_id,
        is_agentic=research_type == ResearchType.DEEP,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan={},
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )

    # TODO: I don't think this is the ideal schema for all use cases
    # find a better schema to store tool and reasoning calls
    for iteration_preparation in ctx.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    for iteration_answer in ctx.aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )
        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Insert message_id / search_doc_id rows into the chat_message__search_doc table.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link to the message
        db_session: The database session
    """
    db_session.add_all(
        ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        for search_doc_id in search_doc_ids
    )
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
def extract_final_answer_from_packets(packet_history: list[Packet]) -> str:
    """Extract the final answer by concatenating all MessageStart/MessageDelta content."""
    message_parts = [
        packet.obj.content
        for packet in packet_history
        if isinstance(packet.obj, (MessageDelta, MessageStart))
    ]
    return "".join(message_parts)

View File

@@ -0,0 +1,177 @@
import asyncio
import queue
import threading
from collections.abc import Iterator
from typing import Generic
from typing import Optional
from typing import TypeVar
from agents import Agent
from agents import RunResultStreaming
from agents.run import Runner
from onyx.chat.turn.models import ChatTurnContext
from onyx.utils.threadpool_concurrency import run_in_background
T = TypeVar("T")
class SyncAgentStream(Generic[T]):
    """
    Convert an async streamed run into a sync iterator with cooperative cancellation.
    Runs the Agent in a background thread.

    Usage:
        adapter = SyncStreamAdapter(
            agent=agent,
            input=input,
            context=context,
            max_turns=100,
            queue_maxsize=0,  # optional backpressure
        )
        for ev in adapter:  # sync iteration
            ...
        # or cancel from elsewhere:
        adapter.cancel()
    """

    # Marker object pushed onto the queue by the worker to signal end-of-stream.
    _SENTINEL = object()

    def __init__(
        self,
        *,
        agent: Agent,
        input: list[dict],
        context: ChatTurnContext,
        max_turns: int = 100,
        queue_maxsize: int = 0,
    ) -> None:
        self._agent = agent
        self._input = input
        self._context = context
        self._max_turns = max_turns
        # Thread-safe hand-off of events from the asyncio worker thread to the
        # synchronous consumer; queue_maxsize > 0 applies backpressure.
        self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._streamed: RunResultStreaming | None = None
        # Exception captured inside the worker, re-raised on the consumer side.
        self._exc: Optional[BaseException] = None
        self._cancel_requested = threading.Event()
        self._started = threading.Event()
        self._done = threading.Event()
        self._start_thread()

    # ---------- public sync API ----------

    def __iter__(self) -> Iterator[T]:
        """Yield stream events until the sentinel; re-raise worker errors."""
        try:
            while True:
                item = self._q.get()
                if item is self._SENTINEL:
                    # If the consumer thread raised, surface it now
                    if self._exc is not None:
                        raise self._exc
                    # Normal completion
                    return
                yield item  # type: ignore[misc,return-value]
        finally:
            # Ensure we fully clean up whether we exited due to exception,
            # StopIteration, or external cancel.
            self.close()

    def cancel(self) -> bool:
        """
        Cooperatively cancel the underlying streamed run and shut down.
        Safe to call multiple times and from any thread.

        Returns True if a cancel was scheduled on the worker loop, False if
        the stream had not started or was already done.
        """
        self._cancel_requested.set()
        loop = self._loop
        streamed = self._streamed
        if loop is not None and streamed is not None and not self._done.is_set():
            # Schedule the cancel on the worker's own event loop thread.
            loop.call_soon_threadsafe(streamed.cancel)
            return True
        return False

    def close(self, *, wait: bool = True) -> None:
        """Idempotent shutdown: cancel, stop the loop, and optionally join."""
        self.cancel()
        # ask the loop to stop if it's still running
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except Exception:
                pass
        # join the thread
        if wait and self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=5.0)

    # ---------- internals ----------

    def _start_thread(self) -> None:
        t = run_in_background(self._thread_main)
        self._thread = t
        # Optionally wait until the loop/worker is started so .cancel() is safe soon after init
        self._started.wait(timeout=1.0)

    def _thread_main(self) -> None:
        # Each stream owns a private event loop on this background thread.
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)

        async def worker() -> None:
            try:
                # Start the streamed run inside the loop thread
                self._streamed = Runner.run_streamed(
                    self._agent,
                    self._input,  # type: ignore[arg-type]
                    context=self._context,
                    max_turns=self._max_turns,
                )
                # If cancel was requested before we created _streamed, honor it now
                if self._cancel_requested.is_set():
                    await self._streamed.cancel()  # type: ignore[func-returns-value]
                # Consume async events and forward into the thread-safe queue
                async for ev in self._streamed.stream_events():
                    # Early exit if a late cancel arrives
                    if self._cancel_requested.is_set():
                        # Try to cancel gracefully; don't break until cancel takes effect
                        try:
                            await self._streamed.cancel()  # type: ignore[func-returns-value]
                        except Exception:
                            pass
                        break
                    # This put() may block if queue_maxsize > 0 (backpressure)
                    self._q.put(ev)
            except BaseException as e:
                # Save exception to surface on the sync iterator side
                self._exc = e
            finally:
                # Signal end-of-stream
                self._q.put(self._SENTINEL)
                self._done.set()

        # Mark started and run the worker to completion
        self._started.set()
        try:
            loop.run_until_complete(worker())
        finally:
            try:
                # Drain pending tasks/callbacks safely
                pending = asyncio.all_tasks(loop=loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(
                        asyncio.gather(*pending, return_exceptions=True)
                    )
            except Exception:
                pass
            finally:
                loop.close()
                self._loop = None
                self._streamed = None

View File

@@ -0,0 +1,50 @@
import dataclasses
from dataclasses import dataclass
from uuid import UUID
from agents import FunctionTool
from agents import Model
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.turn.infra.emitter import Emitter
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@dataclass
class ChatTurnDependencies:
    """Bundle of injected services and tools needed to execute one chat turn."""

    llm_model: Model  # agents-SDK Model driving the turn's agent runs
    llm: LLM  # Onyx LLM interface — presumably for direct (non-agent) calls; TODO confirm
    db_session: Session  # SQLAlchemy session for DB access during the turn
    tools: list[FunctionTool]  # function tools exposed to the agent
    redis_client: Redis  # Redis handle — usage not visible here; confirm against callers
    emitter: Emitter  # event emitter for streaming turn output
    # Optional tool implementations; None when not configured for this turn.
    search_pipeline: SearchTool | None = None
    image_generation_tool: ImageGenerationTool | None = None
    okta_profile_tool: OktaProfileTool | None = None
@dataclass
class ChatTurnContext:
    """Context class to hold search tool and other dependencies"""

    # Identity of the chat turn being executed.
    chat_session_id: UUID
    message_id: int
    # Which research flow is running (see ResearchType enum).
    research_type: ResearchType
    # Injected services/tools for this turn (see ChatTurnDependencies).
    run_dependencies: ChatTurnDependencies
    # Aggregated deep-research context — presumably accumulated across
    # iterations; confirm against producers.
    aggregated_context: AggregatedDRContext
    # Step counter for the current run (starts at 0).
    current_run_step: int = 0
    # Instructions recorded per research iteration.
    iteration_instructions: list[IterationInstructions] = dataclasses.field(
        default_factory=list
    )
    # Raw results of web fetches performed during the turn.
    web_fetch_results: list[dict] = dataclasses.field(default_factory=list)

View File

@@ -24,8 +24,6 @@ APP_PORT = 8080
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
#####
# User Facing Features Configs
#####
@@ -351,10 +349,6 @@ except ValueError:
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
CELERY_WORKER_PRIMARY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
)
@@ -362,18 +356,28 @@ CELERY_WORKER_PRIMARY_CONCURRENCY = int(
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
try:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
os.environ.get(
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
)
)
except ValueError:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
)
# Consolidated background worker (merges heavy, kg_processing, monitoring, user_file_processing)
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 6
)
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
)
CELERY_WORKER_MONITORING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
)
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY") or 2
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192
@@ -511,6 +515,10 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -756,7 +764,7 @@ MAX_FEDERATED_CHUNKS = int(
# NOTE: this should only be enabled if you have purchased an enterprise license.
# if you're interested in an enterprise license, please reach out to us at
# founders@onyx.app OR message Chris Weaver or Yuhong Sun in the Onyx
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
# Discord community https://discord.gg/4NA5SbzrWb
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)

View File

@@ -90,6 +90,7 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)

View File

@@ -6,7 +6,7 @@ from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
@@ -72,12 +72,13 @@ POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
"celery_worker_user_file_processing"
)
@@ -108,7 +109,6 @@ KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
KV_KG_CONFIG_KEY = "kg_config"
# NOTE: we use this timeout / 4 in various places to refresh a lock
@@ -146,6 +146,13 @@ CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
# Doc ID migration can be long-running; use a longer TTL and renew periodically
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT = 10 * 60 # 10 minutes (in seconds)
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
@@ -405,6 +412,7 @@ class OnyxRedisLocks:
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DOCID_MIGRATION_LOCK = "da_lock:user_file_docid_migration"
class OnyxRedisSignals:

View File

@@ -61,21 +61,16 @@ _BASE_EMBEDDING_MODELS = [
dim=1536,
index_name="danswer_chunk_text_embedding_3_small",
),
_BaseEmbeddingModel(
name="google/gemini-embedding-001",
dim=3072,
index_name="danswer_chunk_google_gemini_embedding_001",
),
_BaseEmbeddingModel(
name="google/text-embedding-005",
dim=768,
index_name="danswer_chunk_google_text_embedding_005",
),
_BaseEmbeddingModel(
name="google/textembedding-gecko@003",
dim=768,
index_name="danswer_chunk_google_textembedding_gecko_003",
),
_BaseEmbeddingModel(
name="google/textembedding-gecko@003",
dim=768,
index_name="danswer_chunk_textembedding_gecko_003",
),
_BaseEmbeddingModel(
name="voyage/voyage-large-2-instruct",
dim=1024,

View File

@@ -41,7 +41,7 @@ All new connectors should have tests added to the `backend/tests/daily/connector
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, CheckpointedConnector, or CheckpointedConnectorWithPermSync
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -56,7 +56,7 @@ class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
class BitbucketConnector(
CheckpointedConnector[BitbucketConnectorCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
):
"""Connector for indexing Bitbucket Cloud pull requests.
@@ -266,7 +266,7 @@ class BitbucketConnector(
"""Validate and deserialize a checkpoint instance from JSON."""
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -1,8 +1,10 @@
import os
import time
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from io import BytesIO
from numbers import Integral
from typing import Any
from typing import Optional
@@ -15,6 +17,7 @@ from botocore.exceptions import PartialCredentialsError
from botocore.session import get_session
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
@@ -44,6 +47,10 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
SIZE_THRESHOLD_BUFFER = 64
class BlobStorageConnector(LoadConnector, PollConnector):
def __init__(
self,
@@ -58,6 +65,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.s3_client: Optional[S3Client] = None
self._allow_images: bool | None = None
self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
def set_allow_images(self, allow_images: bool) -> None:
"""Set whether to process images in this connector."""
@@ -195,11 +203,43 @@ class BlobStorageConnector(LoadConnector, PollConnector):
return None
def _download_object(self, key: str) -> bytes:
def _download_object(self, key: str) -> bytes | None:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
return object["Body"].read()
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
body = response["Body"]
try:
if self.size_threshold is None:
return body.read()
return self._read_stream_with_limit(body, key)
finally:
body.close()
def _read_stream_with_limit(self, body: Any, key: str) -> bytes | None:
    """Read an object body, bailing out once it exceeds the size threshold.

    Returns the full payload as bytes, or None when the stream grows past
    ``self.size_threshold`` (plus a small buffer) and should be skipped.
    ``body`` is assumed to be a botocore StreamingBody-like object exposing
    ``read()`` and ``iter_chunks()`` — confirm at the call site.
    """
    # No threshold configured: read everything in one go.
    if self.size_threshold is None:
        return body.read()
    bytes_read = 0
    chunks: list[bytes] = []
    # Cap chunk size so a single read cannot overshoot the limit by much.
    chunk_size = min(
        DOWNLOAD_CHUNK_SIZE, self.size_threshold + SIZE_THRESHOLD_BUFFER
    )
    for chunk in body.iter_chunks(chunk_size=chunk_size):
        if not chunk:
            continue
        chunks.append(chunk)
        bytes_read += len(chunk)
        # Small buffer tolerates minor metadata/actual size mismatches.
        if bytes_read > self.size_threshold + SIZE_THRESHOLD_BUFFER:
            logger.warning(
                f"{key} exceeds size threshold of {self.size_threshold}. Skipping."
            )
            return None
    return b"".join(chunks)
# NOTE: Left in as may be useful for one-off access to documents and sharing across orgs.
# def _get_presigned_url(self, key: str) -> str:
@@ -236,6 +276,51 @@ class BlobStorageConnector(LoadConnector, PollConnector):
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
@staticmethod
def _extract_size_bytes(obj: Mapping[str, Any]) -> int | None:
"""Return the first numeric size field found on the object metadata."""
candidate_keys = (
"Size",
"size",
"ContentLength",
"content_length",
"Content-Length",
"contentLength",
"bytes",
"Bytes",
)
def _normalize(value: Any) -> int | None:
if value is None or isinstance(value, bool):
return None
if isinstance(value, Integral):
return int(value)
try:
numeric = float(value)
except (TypeError, ValueError):
return None
if numeric >= 0 and numeric.is_integer():
return int(numeric)
return None
for key in candidate_keys:
if key in obj:
normalized = _normalize(obj.get(key))
if normalized is not None:
return normalized
for key, value in obj.items():
if not isinstance(key, str):
continue
lowered_key = key.lower()
if "size" in lowered_key or "length" in lowered_key:
normalized = _normalize(value)
if normalized is not None:
return normalized
return None
def _yield_blob_objects(
self,
start: datetime,
@@ -266,6 +351,18 @@ class BlobStorageConnector(LoadConnector, PollConnector):
key = obj["Key"]
link = self._get_blob_link(key)
size_bytes = self._extract_size_bytes(obj)
if (
self.size_threshold is not None
and isinstance(size_bytes, int)
and self.size_threshold is not None
and size_bytes > self.size_threshold
):
logger.warning(
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
)
continue
# Handle image files
if is_accepted_file_ext(file_ext, OnyxExtensionType.Multimedia):
if not self._allow_images:
@@ -277,6 +374,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# Process the image file
try:
downloaded_file = self._download_object(key)
if downloaded_file is None:
continue
# TODO: Refactor to avoid direct DB access in connector
# This will require broader refactoring across the codebase
@@ -309,6 +408,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# Handle text and document files
try:
downloaded_file = self._download_object(key)
if downloaded_file is None:
continue
extraction_result = extract_text_and_images(
BytesIO(downloaded_file), file_name=file_name
)

View File

@@ -5,6 +5,7 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from atlassian.errors import ApiError # type: ignore
from requests.exceptions import HTTPError
from typing_extensions import override
@@ -21,7 +22,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import convert_attachment_to_content
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import process_attachment
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
@@ -41,6 +41,7 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -91,6 +92,7 @@ class ConfluenceCheckpoint(ConnectorCheckpoint):
class ConfluenceConnector(
CheckpointedConnector[ConfluenceCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
CredentialsConnector,
):
def __init__(
@@ -108,6 +110,7 @@ class ConfluenceConnector(
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
scoped_token: bool = False,
) -> None:
self.wiki_base = wiki_base
self.is_cloud = is_cloud
@@ -118,6 +121,7 @@ class ConfluenceConnector(
self.batch_size = batch_size
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self.scoped_token = scoped_token
self._confluence_client: OnyxConfluence | None = None
self._low_timeout_confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
@@ -195,6 +199,7 @@ class ConfluenceConnector(
is_cloud=self.is_cloud,
url=self.wiki_base,
credentials_provider=credentials_provider,
scoped_token=self.scoped_token,
)
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
@@ -207,6 +212,7 @@ class ConfluenceConnector(
url=self.wiki_base,
credentials_provider=credentials_provider,
timeout=3,
scoped_token=self.scoped_token,
)
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
@@ -243,9 +249,26 @@ class ConfluenceConnector(
page_query += " order by lastmodified asc"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
def _construct_attachment_query(
    self,
    confluence_page_id: str,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> str:
    """Build the CQL query for a page's attachments, optionally time-bounded.

    start/end are seconds since the Unix epoch; when given they bound
    ``lastmodified`` in the connector's timezone. NOTE(review): a start/end
    of exactly 0 falls through the truthiness checks below and is treated
    as "no bound" — presumably acceptable; confirm.
    """
    attachment_query = f"type=attachment and container='{confluence_page_id}'"
    attachment_query += self.cql_label_filter
    # Add time filters to avoid reprocessing unchanged attachments during refresh
    if start:
        formatted_start_time = datetime.fromtimestamp(
            start, tz=self.timezone
        ).strftime("%Y-%m-%d %H:%M")
        attachment_query += f" and lastmodified >= '{formatted_start_time}'"
    if end:
        formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
            "%Y-%m-%d %H:%M"
        )
        attachment_query += f" and lastmodified <= '{formatted_end_time}'"
    attachment_query += " order by lastmodified asc"
    return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -299,41 +322,8 @@ class ConfluenceConnector(
sections.append(
TextSection(text=comment_text, link=f"{page_url}#comments")
)
# Process attachments
if "children" in page and "attachment" in page["children"]:
attachments = self.confluence_client.get_attachments_for_page(
page_id, expand="metadata"
)
for attachment in attachments.get("results", []):
# Process each attachment
result = process_attachment(
self.confluence_client,
attachment,
page_id,
self.allow_images,
)
if result and result.text:
# Create a section for the attachment text
attachment_section = TextSection(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
)
sections.append(attachment_section)
elif result and result.file_name:
# Create an ImageSection for image attachments
image_section = ImageSection(
link=f"{page_url}#attachment-{attachment['id']}",
image_file_id=result.file_name,
)
sections.append(image_section)
else:
logger.warning(
f"Error processing attachment '{attachment.get('title')}':",
f"{result.error if result else 'Unknown error'}",
)
# Note: attachments are no longer merged into the page document.
# They are indexed as separate documents downstream.
# Extract metadata
metadata = {}
@@ -382,9 +372,20 @@ class ConfluenceConnector(
)
def _fetch_page_attachments(
self, page: dict[str, Any], doc: Document
) -> Document | ConnectorFailure:
attachment_query = self._construct_attachment_query(page["id"])
self,
page: dict[str, Any],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> tuple[list[Document], list[ConnectorFailure]]:
"""
Inline attachments are added directly to the document as text or image sections by
this function. The returned documents/connectorfailures are for non-inline attachments
and those at the end of the page.
"""
attachment_query = self._construct_attachment_query(page["id"], start, end)
attachment_failures: list[ConnectorFailure] = []
attachment_docs: list[Document] = []
page_url = ""
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
@@ -413,11 +414,17 @@ class ConfluenceConnector(
logger.info(
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
)
# Attempt to get textual content or image summarization:
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
# Attachment document id: use the download URL for stable identity
try:
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["download"], self.is_cloud
)
except Exception as e:
logger.warning(
f"Invalid attachment url for id {attachment['id']}, skipping"
)
logger.debug(f"Error building attachment url: {e}")
continue
try:
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
@@ -430,38 +437,76 @@ class ConfluenceConnector(
content_text, file_storage_name = response
sections: list[TextSection | ImageSection] = []
if content_text:
doc.sections.append(
TextSection(
text=content_text,
link=object_url,
)
)
sections.append(TextSection(text=content_text, link=object_url))
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_id=file_storage_name,
)
sections.append(
ImageSection(link=object_url, image_file_id=file_storage_name)
)
# Build attachment-specific metadata
attachment_metadata: dict[str, str | list[str]] = {}
if "space" in attachment:
attachment_metadata["space"] = attachment["space"].get("name", "")
labels: list[str] = []
if "metadata" in attachment and "labels" in attachment["metadata"]:
for label in attachment["metadata"]["labels"].get("results", []):
labels.append(label.get("name", ""))
if labels:
attachment_metadata["labels"] = labels
page_url = page_url or build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
attachment_metadata["parent_page_id"] = page_url
attachment_id = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
primary_owners: list[BasicExpertInfo] | None = None
if "version" in attachment and "by" in attachment["version"]:
author = attachment["version"]["by"]
display_name = author.get("displayName", "Unknown")
email = author.get("email", "unknown@domain.invalid")
primary_owners = [
BasicExpertInfo(display_name=display_name, email=email)
]
attachment_doc = Document(
id=attachment_id,
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=attachment.get("title", object_url),
metadata=attachment_metadata,
doc_updated_at=(
datetime_from_string(attachment["version"]["when"])
if attachment.get("version")
and attachment["version"].get("when")
else None
),
primary_owners=primary_owners,
)
attachment_docs.append(attachment_doc)
except Exception as e:
logger.error(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if is_atlassian_date_error(
e
): # propagate error to be caught and retried
if is_atlassian_date_error(e):
# propagate error to be caught and retried
raise
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc.id,
document_link=object_url,
),
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {doc.id}",
exception=e,
attachment_failures.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=object_url,
document_link=object_url,
),
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}",
exception=e,
)
)
return doc
return attachment_docs, attachment_failures
def _fetch_document_batches(
self,
@@ -500,12 +545,18 @@ class ConfluenceConnector(
if isinstance(doc_or_failure, ConnectorFailure):
yield doc_or_failure
continue
# Now get attachments for that page:
doc_or_failure = self._fetch_page_attachments(page, doc_or_failure)
# yield completed document (or failure)
yield doc_or_failure
# Now get attachments for that page:
attachment_docs, attachment_failures = self._fetch_page_attachments(
page, start, end
)
# yield attached docs and failures
yield from attachment_docs
yield from attachment_failures
# Create checkpoint once a full page of results is returned
if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:
return checkpoint
@@ -558,7 +609,21 @@ class ConfluenceConnector(
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
return ConfluenceCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_documents(
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs(
start=start,
end=end,
callback=callback,
include_permissions=False,
)
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -568,12 +633,28 @@ class ConfluenceConnector(
Return 'slim' docs (IDs + minimal permission data).
Does not fetch actual text. Used primarily for incremental permission sync.
"""
return self._retrieve_all_slim_docs(
start=start,
end=end,
callback=callback,
include_permissions=True,
)
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
space_level_access_info = get_all_space_permissions(
self.confluence_client, self.is_cloud
)
space_level_access_info: dict[str, ExternalAccess] = {}
if include_permissions:
space_level_access_info = get_all_space_permissions(
self.confluence_client, self.is_cloud
)
def get_external_access(
doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
@@ -600,8 +681,10 @@ class ConfluenceConnector(
doc_metadata_list.append(
SlimDocument(
id=page_id,
external_access=get_external_access(
page_id, page_restrictions, page_ancestors
external_access=(
get_external_access(page_id, page_restrictions, page_ancestors)
if include_permissions
else None
),
)
)
@@ -636,8 +719,12 @@ class ConfluenceConnector(
doc_metadata_list.append(
SlimDocument(
id=attachment_id,
external_access=get_external_access(
attachment_id, attachment_restrictions, []
external_access=(
get_external_access(
attachment_id, attachment_restrictions, []
)
if include_permissions
else None
),
)
)
@@ -648,10 +735,10 @@ class ConfluenceConnector(
if callback and callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
)
if callback:
callback.progress("retrieve_all_slim_documents", 1)
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
yield doc_metadata_list
@@ -676,6 +763,14 @@ class ConfluenceConnector(
f"Unexpected error while validating Confluence settings: {e}"
)
if self.space:
try:
self.low_timeout_confluence_client.get_space(self.space)
except ApiError as e:
raise ConnectorValidationError(
"Invalid Confluence space key provided"
) from e
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
@@ -724,7 +819,7 @@ if __name__ == "__main__":
end = datetime.now().timestamp()
# Fetch all `SlimDocuments`.
for slim_doc in confluence_connector.retrieve_all_slim_documents():
for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync():
print(slim_doc)
# Fetch all `Documents`.

View File

@@ -41,6 +41,7 @@ from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
@@ -87,16 +88,20 @@ class OnyxConfluence:
url: str,
credentials_provider: CredentialsProviderInterface,
timeout: int | None = None,
scoped_token: bool = False,
# should generally not be passed in, but making it overridable for
# easier testing
confluence_user_profiles_override: list[dict[str, str]] | None = (
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
),
) -> None:
self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
url = scoped_url(url, "confluence") if scoped_token else url
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.scoped_token = scoped_token
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
@@ -218,6 +223,34 @@ class OnyxConfluence:
with self._credentials_provider:
credentials, _ = self._renew_credentials()
if self.scoped_token:
# v2 endpoint doesn't always work with scoped tokens, use v1
token = credentials["confluence_access_token"]
probe_url = f"{self.base_url}/rest/api/space?limit=1"
import requests
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
try:
r = requests.get(
probe_url,
headers={"Authorization": f"Bearer {token}"},
timeout=10,
)
r.raise_for_status()
except HTTPError as e:
if e.response.status_code == 403:
logger.warning(
"scoped token authenticated but not valid for probe endpoint (spaces)"
)
else:
if "WWW-Authenticate" in e.response.headers:
logger.warning(
f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
)
logger.warning(f"Full error: {e.response.text}")
raise e
return
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
@@ -236,6 +269,7 @@ class OnyxConfluence:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
logger.info("running with cloud client")
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
@@ -304,7 +338,9 @@ class OnyxConfluence:
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info("Connecting to Confluence with Personal Access Token.")
logger.info(
f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
)
if self._is_cloud:
confluence = Confluence(
url=self._url,
@@ -930,6 +966,13 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def sanitize_attachment_title(title: str) -> str:
    """
    Sanitize the attachment title to be a valid HTML attribute.

    Replaces "<", ">", " ", and ":" with underscores in a single pass via
    str.translate, instead of four chained str.replace scans.
    """
    return title.translate(str.maketrans({"<": "_", ">": "_", " ": "_", ":": "_"}))
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
@@ -1032,6 +1075,16 @@ def extract_text_from_confluence_html(
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
for html_attachment in soup.findAll("ri:attachment"):
# This extracts the text from inline attachments in the page so they can be
# represented in the document text as plain text
try:
html_attachment.replaceWith(
f"<attachment>{sanitize_attachment_title(html_attachment.attrs['ri:filename'])}</attachment>"
) # to be replaced later
except Exception as e:
logger.warning(f"Error processing ac:attachment: {e}")
return format_document_soup(soup)

View File

@@ -5,7 +5,10 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TypeVar
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
@@ -148,3 +151,17 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
def is_atlassian_date_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)
def get_cloudId(base_url: str) -> str:
    """Fetch the Atlassian cloudId for *base_url* from the tenant_info edge endpoint.

    Raises requests.HTTPError if the endpoint responds with an error status.
    """
    endpoint = urljoin(base_url, "/_edge/tenant_info")
    resp = requests.get(endpoint, timeout=10)
    resp.raise_for_status()
    payload = resp.json()
    return payload["cloudId"]
def scoped_url(url: str, product: str) -> str:
    """Translate a direct Atlassian URL into its api.atlassian.com scoped form.

    Looks up the tenant's cloudId from the URL's origin, then rebuilds the
    URL as https://api.atlassian.com/ex/<product>/<cloudId><original path>.
    """
    parsed = urlparse(url)
    origin = f"{parsed.scheme}://{parsed.netloc}"
    cloud_id = get_cloudId(origin)
    return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}"

View File

@@ -1,3 +1,4 @@
import importlib
from typing import Any
from typing import Type
@@ -6,60 +7,16 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.bitbucket.connector import BitbucketConnector
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
from onyx.connectors.gitbook.connector import GitbookConnector
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.gitlab.connector import GitlabConnector
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.jira.connector import JiraConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.outline.connector import OutlineConnector
from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
@@ -72,101 +29,75 @@ class ConnectorMissingException(Exception):
pass
# Cache for already imported connector classes
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}
def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
    """Dynamically import, cache, and return the connector class for *source*.

    Results are memoized in the module-level _connector_cache so each
    connector module is imported at most once per process.

    Raises:
        ConnectorMissingException: if *source* has no registry mapping or the
            mapped module/class cannot be imported.
    """
    if source in _connector_cache:
        return _connector_cache[source]

    if source not in CONNECTOR_CLASS_MAP:
        raise ConnectorMissingException(f"Connector not found for source={source}")

    mapping = CONNECTOR_CLASS_MAP[source]
    try:
        module = importlib.import_module(mapping.module_path)
        connector_class = getattr(module, mapping.class_name)
    except (ImportError, AttributeError) as e:
        # chain the original exception so the underlying import failure's
        # traceback is preserved for debugging
        raise ConnectorMissingException(
            f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
        ) from e
    _connector_cache[source] = connector_class
    return connector_class
def _validate_connector_supports_input_type(
    connector: Type[BaseConnector],
    input_type: InputType | None,
    source: DocumentSource,
) -> None:
    """Raise ConnectorMissingException if *connector* cannot serve *input_type*.

    A None *input_type* is always accepted.
    """
    if input_type is None:
        return

    supported = True
    if input_type == InputType.LOAD_STATE:
        supported = issubclass(connector, LoadConnector)
    elif input_type == InputType.POLL:
        # Either poll or checkpoint works for this, in the future
        # all connectors should be checkpoint connectors
        supported = issubclass(connector, PollConnector) or issubclass(
            connector, CheckpointedConnector
        )
    elif input_type == InputType.EVENT:
        supported = issubclass(connector, EventConnector)

    if not supported:
        raise ConnectorMissingException(
            f"Connector for source={source} does not accept input_type={input_type}"
        )
def identify_connector_class(
    source: DocumentSource,
    input_type: InputType | None = None,
) -> Type[BaseConnector]:
    """Return the connector class registered for *source*.

    The class is imported lazily (and cached) through the registry mapping,
    then validated against *input_type* so callers fail fast with a clear
    error instead of at first use.

    Raises:
        ConnectorMissingException: if no connector is registered for *source*
            or the resolved class does not support *input_type*.
    """
    # Load the connector class using lazy loading
    connector = _load_connector_class(source)

    # Validate connector supports the requested input_type
    _validate_connector_supports_input_type(connector, input_type, source)

    return connector

View File

@@ -219,12 +219,19 @@ def _get_batch_rate_limited(
def _get_userinfo(user: NamedUser) -> dict[str, str]:
    """Return the login/name/email of a GitHub user, omitting missing fields.

    Attribute access on a NamedUser can trigger a lazy API call; _safe_get
    swallows GithubException so a partially-restricted profile does not abort
    the whole sync, and None values are dropped from the result.
    """

    def _safe_get(attr_name: str) -> str | None:
        # getattr may hit the GitHub API lazily; degrade to None on failure
        try:
            return cast(str | None, getattr(user, attr_name))
        except GithubException:
            logger.debug(f"Error getting {attr_name} for user")
            return None

    return {
        k: v
        for k, v in {
            "login": _safe_get("login"),
            "name": _safe_get("name"),
            "email": _safe_get("email"),
        }.items()
        if v is not None
    }

View File

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
@@ -232,7 +232,7 @@ def thread_to_document(
)
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
@@ -397,10 +397,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
except HttpError as e:
if _is_mail_service_disabled_error(e):
logger.warning(
@@ -431,7 +431,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -64,7 +64,7 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -153,7 +153,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnector, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1296,7 +1296,7 @@ class GoogleDriveConnector(
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -38,7 +38,7 @@ class HighspotSpot(BaseModel):
name: str
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
"""
Connector for loading data from Highspot.
@@ -362,7 +362,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
description = item_details.get("description", "")
return title, description
def retrieve_all_slim_documents(
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -1,14 +1,18 @@
import re
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
import requests
from hubspot import HubSpot # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -25,6 +29,10 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
# Available HubSpot object types
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
HUBSPOT_PAGE_SIZE = 100
T = TypeVar("T")
logger = setup_logger()
@@ -38,6 +46,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self._access_token = access_token
self._portal_id: str | None = None
self._rate_limiter = HubSpotRateLimiter()
# Set object types to fetch, default to all available types
if object_types is None:
@@ -77,6 +86,37 @@ class HubSpotConnector(LoadConnector, PollConnector):
"""Set the portal ID."""
self._portal_id = value
def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Invoke *func* through the shared HubSpot rate limiter.

    All HubSpot API calls funnel through here so self._rate_limiter can
    throttle them centrally; the wrapped call's return value is passed
    through unchanged.
    """
    return self._rate_limiter.call(func, *args, **kwargs)
def _paginated_results(
    self,
    fetch_page: Callable[..., Any],
    **kwargs: Any,
) -> Generator[Any, None, None]:
    """Yield every result from a paged HubSpot endpoint, following cursors.

    *fetch_page* is called (rate limited) repeatedly; the "after" cursor from
    each page's paging.next is fed into the next call until no cursor remains.
    A default page size of HUBSPOT_PAGE_SIZE is applied unless the caller
    passed their own "limit".
    """
    call_kwargs = dict(kwargs)
    call_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)

    cursor: str | None = None
    while True:
        request_kwargs = dict(call_kwargs)
        if cursor is not None:
            request_kwargs["after"] = cursor
        page = self._call_hubspot(fetch_page, **request_kwargs)

        yield from getattr(page, "results", [])

        paging = getattr(page, "paging", None)
        next_page = getattr(paging, "next", None) if paging else None
        cursor = getattr(next_page, "after", None) if next_page is not None else None
        if cursor is None:
            break
def _clean_html_content(self, html_content: str) -> str:
"""Clean HTML content and extract raw text"""
if not html_content:
@@ -150,78 +190,82 @@ class HubSpotConnector(LoadConnector, PollConnector):
) -> list[dict[str, Any]]:
"""Get associated objects for a given object"""
try:
associations = api_client.crm.associations.v4.basic_api.get_page(
associations_iter = self._paginated_results(
api_client.crm.associations.v4.basic_api.get_page,
object_type=from_object_type,
object_id=object_id,
to_object_type=to_object_type,
)
associated_objects = []
if associations.results:
object_ids = [assoc.to_object_id for assoc in associations.results]
object_ids = [assoc.to_object_id for assoc in associations_iter]
# Batch get the associated objects
if to_object_type == "contacts":
for obj_id in object_ids:
try:
obj = api_client.crm.contacts.basic_api.get_by_id(
contact_id=obj_id,
properties=[
"firstname",
"lastname",
"email",
"company",
"jobtitle",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
associated_objects: list[dict[str, Any]] = []
elif to_object_type == "companies":
for obj_id in object_ids:
try:
obj = api_client.crm.companies.basic_api.get_by_id(
company_id=obj_id,
properties=[
"name",
"domain",
"industry",
"city",
"state",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch company {obj_id}: {e}")
if to_object_type == "contacts":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.contacts.basic_api.get_by_id,
contact_id=obj_id,
properties=[
"firstname",
"lastname",
"email",
"company",
"jobtitle",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
elif to_object_type == "deals":
for obj_id in object_ids:
try:
obj = api_client.crm.deals.basic_api.get_by_id(
deal_id=obj_id,
properties=[
"dealname",
"amount",
"dealstage",
"closedate",
"pipeline",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
elif to_object_type == "companies":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.companies.basic_api.get_by_id,
company_id=obj_id,
properties=[
"name",
"domain",
"industry",
"city",
"state",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch company {obj_id}: {e}")
elif to_object_type == "tickets":
for obj_id in object_ids:
try:
obj = api_client.crm.tickets.basic_api.get_by_id(
ticket_id=obj_id,
properties=["subject", "content", "hs_ticket_priority"],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
elif to_object_type == "deals":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.deals.basic_api.get_by_id,
deal_id=obj_id,
properties=[
"dealname",
"amount",
"dealstage",
"closedate",
"pipeline",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
elif to_object_type == "tickets":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.tickets.basic_api.get_by_id,
ticket_id=obj_id,
properties=["subject", "content", "hs_ticket_priority"],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
return associated_objects
@@ -239,33 +283,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
) -> list[dict[str, Any]]:
"""Get notes associated with a given object"""
try:
# Get associations to notes (engagement type)
associations = api_client.crm.associations.v4.basic_api.get_page(
associations_iter = self._paginated_results(
api_client.crm.associations.v4.basic_api.get_page,
object_type=object_type,
object_id=object_id,
to_object_type="notes",
)
associated_notes = []
if associations.results:
note_ids = [assoc.to_object_id for assoc in associations.results]
note_ids = [assoc.to_object_id for assoc in associations_iter]
# Batch get the associated notes
for note_id in note_ids:
try:
# Notes are engagements in HubSpot, use the engagements API
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note_id,
properties=[
"hs_note_body",
"hs_timestamp",
"hs_created_by",
"hubspot_owner_id",
],
)
associated_notes.append(note.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch note {note_id}: {e}")
associated_notes = []
for note_id in note_ids:
try:
# Notes are engagements in HubSpot, use the engagements API
note = self._call_hubspot(
api_client.crm.objects.notes.basic_api.get_by_id,
note_id=note_id,
properties=[
"hs_note_body",
"hs_timestamp",
"hs_created_by",
"hubspot_owner_id",
],
)
associated_notes.append(note.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch note {note_id}: {e}")
return associated_notes
@@ -358,7 +402,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_tickets = api_client.crm.tickets.get_all(
tickets_iter = self._paginated_results(
api_client.crm.tickets.basic_api.get_page,
properties=[
"subject",
"content",
@@ -371,7 +417,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for ticket in all_tickets:
for ticket in tickets_iter:
updated_at = ticket.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -459,7 +505,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_companies = api_client.crm.companies.get_all(
companies_iter = self._paginated_results(
api_client.crm.companies.basic_api.get_page,
properties=[
"name",
"domain",
@@ -475,7 +523,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for company in all_companies:
for company in companies_iter:
updated_at = company.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -582,7 +630,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_deals = api_client.crm.deals.get_all(
deals_iter = self._paginated_results(
api_client.crm.deals.basic_api.get_page,
properties=[
"dealname",
"amount",
@@ -598,7 +648,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for deal in all_deals:
for deal in deals_iter:
updated_at = deal.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -703,7 +753,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
all_contacts = api_client.crm.contacts.get_all(
contacts_iter = self._paginated_results(
api_client.crm.contacts.basic_api.get_page,
properties=[
"firstname",
"lastname",
@@ -721,7 +773,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for contact in all_contacts:
for contact in contacts_iter:
updated_at = contact.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue

Some files were not shown because too many files have changed in this diff Show More