Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-02 06:05:46 +00:00)

Compare commits: 3 commits (content-re ... v2.0.0-bet)

| Author | SHA1 | Date |
|---|---|---|
|  | ed40cbdd00 |  |
|  | b36910240d |  |
|  | 488b27ba04 |  |
@@ -8,9 +8,9 @@ on:
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
@@ -33,7 +33,16 @@ jobs:
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
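Editorial aside (not part of the workflow): the env expressions and the stable-version check above decide which extra tags an image gets. A rough bash re-implementation of that mapping, using made-up ref names for illustration only:

```bash
#!/usr/bin/env bash
# Sketch of the tag logic above; ref names are hypothetical examples.
for ref in v2.0.0 v2.0.0-cloud nightly-latest-20251009 cloud-latest; do
  tags=("$ref")
  # "latest": ref contains "latest" but is not a cloud build,
  # or the ref is a stable vX.Y.Z release (and not a cloud build)
  if { [[ "$ref" == *latest* ]] && [[ "$ref" != *cloud* ]]; } \
     || { [[ "$ref" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "$ref" != *cloud* ]]; }; then
    tags+=(latest)
  fi
  # "edge": nightly builds
  if [[ "$ref" == nightly-latest* ]]; then
    tags+=(edge)
  fi
  echo "$ref -> ${tags[*]}"
done
```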
@@ -46,7 +55,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -119,7 +129,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
@@ -11,8 +11,8 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
|
||||
|
||||
# don't tag cloud images with "latest"
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -145,6 +145,15 @@ jobs:
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
@@ -157,11 +166,16 @@ jobs:
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
|
||||
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
|
||||
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
|
||||
fi
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@v3
|
||||
|
||||
@@ -7,7 +7,10 @@ on:
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
# tag nightly builds with "edge"
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
DEPLOYMENT: standalone
|
||||
|
||||
jobs:
|
||||
@@ -45,6 +48,15 @@ jobs:
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -57,7 +69,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -126,7 +139,8 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
|
||||
.github/workflows/helm-chart-releases.yml (vendored, 6 lines changed)
@@ -25,9 +25,11 @@ jobs:
|
||||
|
||||
- name: Add required Helm repositories
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add keda https://kedacore.github.io/charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Build chart dependencies
|
||||
|
||||
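The hunk above grows the set of upstream repos the release job registers before it builds chart dependencies. A hedged sketch of reproducing that step locally; the chart path is an assumption, not something stated in this diff:

```bash
# Register the newly added upstream repos, then build the chart's dependencies.
# CHART_DIR is an assumption; point it at wherever the Onyx chart lives in your checkout.
CHART_DIR=deployment/helm/charts/onyx
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
helm dependency build "$CHART_DIR"
```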
@@ -20,6 +20,7 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
.github/workflows/pr-helm-chart-testing.yml (vendored, 71 lines changed)
@@ -65,35 +65,45 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo update
|
||||
|
||||
- name: Pre-pull critical images
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling critical images to avoid timeout ==="
|
||||
# Get kind cluster name
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
# Pre-pull images that are likely to be used
|
||||
echo "Pre-pulling PostgreSQL image..."
|
||||
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
|
||||
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
|
||||
|
||||
echo "Pre-pulling Redis image..."
|
||||
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
|
||||
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
|
||||
|
||||
echo "Pre-pulling Onyx images..."
|
||||
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
|
||||
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
|
||||
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
|
||||
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
@@ -149,6 +159,7 @@ jobs:
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
@@ -156,8 +167,10 @@ jobs:
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.primary.persistence.enabled=false \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
@@ -173,8 +186,16 @@ jobs:
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
|
||||
echo "=== Installation completed successfully ==="
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
@@ -199,7 +220,7 @@ jobs:
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
|
||||
.github/workflows/pr-integration-tests.yml (vendored, 7 lines changed)
@@ -22,9 +22,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -131,6 +133,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -158,6 +161,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -191,6 +195,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests:
|
||||
needs:
|
||||
@@ -337,9 +342,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
@@ -19,9 +19,11 @@ env:
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
@@ -128,6 +130,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-backend:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: blacksmith-16vcpu-ubuntu-2404-arm
|
||||
@@ -155,6 +158,7 @@ jobs:
|
||||
push: true
|
||||
outputs: type=registry
|
||||
provenance: false
|
||||
no-cache: true
|
||||
|
||||
build-integration-image:
|
||||
needs: prepare-build
|
||||
@@ -188,6 +192,7 @@ jobs:
|
||||
tags: ${{ env.PRIVATE_REGISTRY }}/integration-test-onyx-integration:test-${{ github.run_id }}
|
||||
push: true
|
||||
outputs: type=registry
|
||||
no-cache: true
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
@@ -334,9 +339,11 @@ jobs:
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
|
||||
.github/workflows/pr-python-connector-tests.yml (vendored, 65 lines changed)
@@ -20,11 +20,13 @@ env:
|
||||
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# Jira
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
|
||||
# Gong
|
||||
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
|
||||
@@ -132,7 +134,24 @@ jobs:
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
- name: Run Tests
|
||||
- name: Detect Connector changes
|
||||
id: changes
|
||||
uses: dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
hubspot:
|
||||
- 'backend/onyx/connectors/hubspot/**'
|
||||
- 'backend/tests/daily/connectors/hubspot/**'
|
||||
salesforce:
|
||||
- 'backend/onyx/connectors/salesforce/**'
|
||||
- 'backend/tests/daily/connectors/salesforce/**'
|
||||
github:
|
||||
- 'backend/onyx/connectors/github/**'
|
||||
- 'backend/tests/daily/connectors/github/**'
|
||||
file_processing:
|
||||
- 'backend/onyx/file_processing/**'
|
||||
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
@@ -142,7 +161,49 @@ jobs:
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors
|
||||
backend/tests/daily/connectors \
|
||||
--ignore backend/tests/daily/connectors/hubspot \
|
||||
--ignore backend/tests/daily/connectors/salesforce \
|
||||
--ignore backend/tests/daily/connectors/github
|
||||
|
||||
- name: Run HubSpot Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors/hubspot
|
||||
|
||||
- name: Run Salesforce Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.salesforce == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors/salesforce
|
||||
|
||||
- name: Run GitHub Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.github == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
-n 8 \
|
||||
--dist loadfile \
|
||||
--durations=8 \
|
||||
-o junit_family=xunit2 \
|
||||
-xv \
|
||||
--ff \
|
||||
backend/tests/daily/connectors/github
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
|
||||
@@ -34,8 +34,7 @@ repos:
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
additional_dependencies:
|
||||
- prettier
|
||||
language_version: system
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
.vscode/launch.template.jsonc (vendored, 200 lines changed)
@@ -23,12 +23,10 @@
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery background",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
@@ -42,16 +40,32 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"name": "Celery (lightweight mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery background",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (standard mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery kg_processing",
|
||||
"Celery monitoring",
|
||||
"Celery user_file_processing",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
@@ -199,6 +213,35 @@
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery background",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery background Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
@@ -221,13 +264,100 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.kg_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=kg_processing@%n",
|
||||
"-Q",
|
||||
"kg_processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery kg_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user_file_processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
@@ -311,58 +441,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user file processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--pool=threads",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user file processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
|
||||
AGENTS.md (32 lines changed)
@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
   - Single thread (monitoring doesn't need parallelism)
   - Cloud-specific monitoring tasks

8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
   - Processes user-uploaded files
   - Handles user file indexing and project synchronization
   - Configurable concurrency

9. **Beat Worker** (`beat`)
   - Celery's scheduler for periodic tasks
   - Uses DynamicTenantScheduler for multi-tenant support
   - Schedules tasks like:
@@ -82,6 +87,31 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
   - Monitoring tasks (every 5 minutes)
   - Cleanup tasks (hourly)

#### Worker Deployment Modes

Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:

**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
  - Pruning operations (from `heavy` worker)
  - Knowledge graph processing (from `kg_processing` worker)
  - Monitoring tasks (from `monitoring` worker)
  - User file processing (from `user_file_processing` worker)
- Lower resource footprint (single worker process)
- Suitable for smaller deployments or development environments
- Default concurrency: 6 threads

**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load

The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns (see the launch-command sketch at the end of this section)

#### Key Features

- **Thread-based Workers**: All workers use thread pools (not processes) for stability
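To make the two modes above concrete, here is a rough sketch of the worker launch commands each mode implies, pieced together from the VS Code launch configurations earlier in this diff; exact flags and queue lists in your deployment may differ.

```bash
# Lightweight mode: one consolidated worker covers pruning, kg processing,
# monitoring, and user file processing (queues taken from the launch config in this diff).
celery -A onyx.background.celery.versioned_apps.background worker \
  --pool=threads --concurrency=6 --prefetch-multiplier=1 --loglevel=INFO \
  -Q connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync

# Standard mode: the same queues split across dedicated workers (abbreviated sketch).
celery -A onyx.background.celery.versioned_apps.heavy worker --pool=threads \
  -Q connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation
celery -A onyx.background.celery.versioned_apps.kg_processing worker --pool=threads --concurrency=2 -Q kg_processing
celery -A onyx.background.celery.versioned_apps.monitoring worker --pool=threads --concurrency=1 -Q monitoring
celery -A onyx.background.celery.versioned_apps.user_file_processing worker --pool=threads --concurrency=2 \
  -Q user_file_processing,user_file_project_sync
```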
CLAUDE.md (36 lines changed)
@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
   - Single thread (monitoring doesn't need parallelism)
   - Cloud-specific monitoring tasks

8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
   - Processes user-uploaded files
   - Handles user file indexing and project synchronization
   - Configurable concurrency

9. **Beat Worker** (`beat`)
   - Celery's scheduler for periodic tasks
   - Uses DynamicTenantScheduler for multi-tenant support
   - Schedules tasks like:
@@ -82,11 +87,36 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
   - Monitoring tasks (every 5 minutes)
   - Cleanup tasks (hourly)

#### Worker Deployment Modes

Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:

**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
  - Pruning operations (from `heavy` worker)
  - Knowledge graph processing (from `kg_processing` worker)
  - Monitoring tasks (from `monitoring` worker)
  - User file processing (from `user_file_processing` worker)
- Lower resource footprint (single worker process)
- Suitable for smaller deployments or development environments
- Default concurrency: 6 threads

**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load

The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns

#### Key Features

- **Thread-based Workers**: All workers use thread pools (not processes) for stability
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
  middleware layer that automatically finds the appropriate tenant ID when sending tasks
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
  middleware layer that automatically finds the appropriate tenant ID when sending tasks
  via Celery Beat.
- **Task Prioritization**: High, Medium, Low priority queues
- **Monitoring**: Built-in heartbeat and liveness checking
@@ -13,8 +13,7 @@ As an open source project in a rapidly changing space, we welcome all contributi
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.

To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
via [Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) or [email](mailto:founders@onyx.app).
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).

Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
@@ -28,8 +27,7 @@ Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.

Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-34lu4m7xg-TsKGO6h8PDvR5W27zTdyhA) /
[Discord](https://discord.gg/TDJ59cGV2X) directly about anything at all.
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.

### Contributing Code

@@ -46,9 +44,7 @@ Our goal is to make contributing as easy as possible. If you run into any issues
That way we can help future contributors and users can avoid the same issue.

We also have support channels and generally interesting discussions on our
[Slack](https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA)
and
[Discord](https://discord.gg/TDJ59cGV2X).
[Discord](https://discord.gg/4NA5SbzrWb).

We would love to see you there!

@@ -105,6 +101,11 @@ pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```

Fix vscode/cursor auto-imports:

```bash
pip install -e .
```

Install Playwright for Python (headless browser required by the Web Connector)

In the activated Python virtualenv, install Playwright for Python by running:
@@ -117,8 +118,15 @@ You may have to deactivate and reactivate your virtualenv for `playwright` to ap

#### Frontend: Node dependencies

Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
Once the above is done, navigate to `onyx/web` run:
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
to manage your Node installations. Once installed, you can run

```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```

Navigate to `onyx/web` and run:

```bash
npm i
@@ -129,8 +137,6 @@ npm i

### Backend

For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
First, install pre-commit (if you don't have it already) following the instructions
[here](https://pre-commit.com/#installation).

With the virtual environment active, install the pre-commit library with:

@@ -150,15 +156,17 @@ To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend`

### Web

We use `prettier` for formatting. The desired version (2.8.8) will be installed via a `npm i` from the `onyx/web` directory.
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Please double check that prettier passes before creating a pull request.

Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
Re-stage your changes and commit again.

# Running the application for development

## Developing using VSCode Debugger (recommended)

We highly recommend using VSCode debugger for development.
**We highly recommend using VSCode debugger for development.**
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.

Otherwise, you can follow the instructions below to run the application for development.
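Tying together the formatting and hook setup described in the Backend and Web sections of this CONTRIBUTING.md excerpt, a typical pre-PR check might look like the following sketch (run from the repo root, with your virtualenv active):

```bash
# Format the frontend, then run the configured hooks and the manual mypy check before pushing.
cd web && npx prettier --write . && cd ..
pre-commit run --all-files          # black, reorder-python-imports, prettier, etc. over the whole tree
cd backend && python -m mypy . && cd ..   # the manual mypy check mentioned above
```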
@@ -21,6 +21,9 @@ Before starting, make sure the Docker Daemon is running.
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.

Note: Clear and Restart External Volumes and Containers will reset your postgres and Vespa (relational-db and index).
Only run this if you are okay with wiping your data.

## Features

- Hot reload is enabled for the web server and API servers
@@ -111,6 +111,8 @@ COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh

# Put logo in assets
COPY ./assets /app/assets
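The Dockerfile hunk above copies the new supervisord entrypoint into the backend image and marks it executable. A quick way to sanity-check that in a local build; the image tag and build context below are assumptions, not taken from this diff:

```bash
# Build the backend image and confirm the entrypoint script landed with execute permissions.
docker build -t onyx-backend:local backend/
docker run --rm --entrypoint /bin/sh onyx-backend:local \
  -c 'ls -l /app/scripts/supervisord_entrypoint.sh'
```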
@@ -0,0 +1,28 @@
|
||||
"""reset userfile document_id_migrated field
|
||||
|
||||
Revision ID: 40926a4dab77
|
||||
Revises: 64bd5677aeb6
|
||||
Create Date: 2025-10-06 16:10:32.898668
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "40926a4dab77"
|
||||
down_revision = "64bd5677aeb6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set all existing records to not migrated
|
||||
op.execute(
|
||||
"UPDATE user_file SET document_id_migrated = FALSE "
|
||||
"WHERE document_id_migrated IS DISTINCT FROM FALSE;"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# No-op
|
||||
pass
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add image input support to model config
|
||||
|
||||
Revision ID: 64bd5677aeb6
|
||||
Revises: b30353be4eec
|
||||
Create Date: 2025-09-28 15:48:12.003612
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "64bd5677aeb6"
|
||||
down_revision = "b30353be4eec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"model_configuration",
|
||||
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
# Seems to be left over from when model visibility was introduced and a nullable field.
|
||||
# Set any null is_visible values to False
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("model_configuration", "supports_image_input")
|
||||
backend/alembic/versions/96a5702df6aa_mcp_tool_enabled.py (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
"""mcp_tool_enabled
|
||||
|
||||
Revision ID: 96a5702df6aa
|
||||
Revises: 40926a4dab77
|
||||
Create Date: 2025-10-09 12:10:21.733097
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "96a5702df6aa"
|
||||
down_revision = "40926a4dab77"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_tool_mcp_server_enabled",
|
||||
"tool",
|
||||
["mcp_server_id", "enabled"],
|
||||
)
|
||||
# Remove the server default so application controls defaulting
|
||||
op.alter_column("tool", "enabled", server_default=None)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(DELETE_DISABLED_TOOLS_SQL)
|
||||
op.drop_index("ix_tool_mcp_server_enabled", table_name="tool")
|
||||
op.drop_column("tool", "enabled")
|
||||
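The new Alembic revisions above are applied with the standard Alembic CLI; a minimal sketch, assuming it is run from `backend/` where the Alembic config lives (that path is an assumption based on the repo layout shown here):

```bash
# Show where the migration history currently points, then apply the new revisions.
alembic history --verbose | head -n 20
alembic upgrade head     # runs 64bd5677aeb6, then 40926a4dab77, then 96a5702df6aa
alembic downgrade -1     # step back one revision if something looks wrong
```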
@@ -1,8 +1,13 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
@@ -10,6 +15,7 @@ from fastapi import status
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -30,43 +36,156 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_PUBLIC_KEY_FETCH_ATTEMPTS = 2
|
||||
|
||||
|
||||
class PublicKeyFormat(Enum):
|
||||
JWKS = "jwks"
|
||||
PEM = "pem"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_public_key() -> str | None:
|
||||
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
|
||||
"""Fetch and cache the raw JWT verification material."""
|
||||
if JWT_PUBLIC_KEY_URL is None:
|
||||
logger.error("JWT_PUBLIC_KEY_URL is not set")
|
||||
return None
|
||||
|
||||
response = requests.get(JWT_PUBLIC_KEY_URL)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
try:
|
||||
response = requests.get(JWT_PUBLIC_KEY_URL)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as exc:
|
||||
logger.error(f"Failed to fetch JWT public key: {str(exc)}")
|
||||
return None
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
raw_body = response.text
|
||||
body_lstripped = raw_body.lstrip()
|
||||
|
||||
if "application/json" in content_type or body_lstripped.startswith("{"):
|
||||
try:
|
||||
data = response.json()
|
||||
except ValueError:
|
||||
logger.error("JWT public key URL returned invalid JSON")
|
||||
return None
|
||||
|
||||
if isinstance(data, dict) and "keys" in data:
|
||||
return data, PublicKeyFormat.JWKS
|
||||
|
||||
logger.error(
|
||||
"JWT public key URL returned JSON but no JWKS 'keys' field was found"
|
||||
)
|
||||
return None
|
||||
|
||||
body = raw_body.strip()
|
||||
if not body:
|
||||
logger.error("JWT public key URL returned an empty response")
|
||||
return None
|
||||
|
||||
return body, PublicKeyFormat.PEM
|
||||
|
||||
|
||||
def get_public_key(token: str) -> RSAPublicKey | str | None:
|
||||
"""Return the concrete public key used to verify the provided JWT token."""
|
||||
payload = _fetch_public_key_payload()
|
||||
if payload is None:
|
||||
logger.error("Failed to retrieve public key payload")
|
||||
return None
|
||||
|
||||
key_material, key_format = payload
|
||||
|
||||
if key_format is PublicKeyFormat.JWKS:
|
||||
jwks_data = cast(dict[str, Any], key_material)
|
||||
return _resolve_public_key_from_jwks(token, jwks_data)
|
||||
|
||||
return cast(str, key_material)
|
||||
|
||||
|
||||
def _resolve_public_key_from_jwks(
|
||||
token: str, jwks_payload: dict[str, Any]
|
||||
) -> RSAPublicKey | None:
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except PyJWTError as e:
|
||||
logger.error(f"Unable to parse JWT header: {str(e)}")
|
||||
return None
|
||||
|
||||
keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
|
||||
if not keys:
|
||||
logger.error("JWKS payload did not contain any keys")
|
||||
return None
|
||||
|
||||
kid = header.get("kid")
|
||||
thumbprint = header.get("x5t")
|
||||
|
||||
candidates = []
|
||||
if kid:
|
||||
candidates = [k for k in keys if k.get("kid") == kid]
|
||||
if not candidates and thumbprint:
|
||||
candidates = [k for k in keys if k.get("x5t") == thumbprint]
|
||||
if not candidates and len(keys) == 1:
|
||||
candidates = keys
|
||||
|
||||
if not candidates:
|
||||
logger.warning(
|
||||
"No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
|
||||
)
|
||||
return None
|
||||
|
||||
if len(candidates) > 1:
|
||||
logger.warning(
|
||||
"Multiple JWKs matched token header kid=%s; selecting the first occurrence",
|
||||
kid,
|
||||
)
|
||||
|
||||
jwk = candidates[0]
|
||||
try:
|
||||
return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
|
||||
try:
|
||||
public_key_pem = get_public_key()
|
||||
if public_key_pem is None:
|
||||
logger.error("Failed to retrieve public key")
|
||||
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
|
||||
public_key = get_public_key(token)
|
||||
if public_key is None:
|
||||
logger.error("Unable to resolve a public key for JWT verification")
|
||||
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
|
||||
_fetch_public_key_payload.cache_clear()
|
||||
continue
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = jwt_decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
except InvalidTokenError as e:
|
||||
logger.error(f"Invalid JWT token: {str(e)}")
|
||||
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
|
||||
_fetch_public_key_payload.cache_clear()
|
||||
continue
|
||||
return None
|
||||
except PyJWTError as e:
|
||||
logger.error(f"JWT decoding error: {str(e)}")
|
||||
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
|
||||
_fetch_public_key_payload.cache_clear()
|
||||
continue
|
||||
return None
|
||||
|
||||
payload = jwt_decode(
|
||||
token,
|
||||
public_key_pem,
|
||||
algorithms=["RS256"],
|
||||
audience=None,
|
||||
)
|
||||
email = payload.get("email")
|
||||
if email:
|
||||
result = await async_db_session.execute(
|
||||
select(User).where(func.lower(User.email) == func.lower(email))
|
||||
)
|
||||
return result.scalars().first()
|
||||
except InvalidTokenError:
|
||||
logger.error("Invalid JWT token")
|
||||
get_public_key.cache_clear()
|
||||
except PyJWTError as e:
|
||||
logger.error(f"JWT decoding error: {str(e)}")
|
||||
get_public_key.cache_clear()
|
||||
logger.warning(
|
||||
"JWT token decoded successfully but no email claim found; skipping auth"
|
||||
)
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
backend/ee/onyx/background/celery/apps/background.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from onyx.background.celery.apps.background import celery_app


celery_app.autodiscover_tasks(
    [
        "ee.onyx.background.celery.tasks.doc_permission_syncing",
        "ee.onyx.background.celery.tasks.external_group_syncing",
        "ee.onyx.background.celery.tasks.cleanup",
        "ee.onyx.background.celery.tasks.tenant_provisioning",
        "ee.onyx.background.celery.tasks.query_history",
    ]
)
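The new EE shim above registers extra task modules on the consolidated background app. One way to confirm the EE tasks are actually picked up by a running worker is to list its registered tasks; this is a sketch and assumes the broker and a worker are already up:

```bash
# Ask running workers which tasks they registered; grep for the EE query-history task.
celery -A onyx.background.celery.versioned_apps.background inspect registered \
  | grep -i query_history
```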
@@ -1,123 +1,4 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def export_query_history_task(
|
||||
self: Task,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start_time: datetime,
|
||||
# Need to include the tenant_id since the TenantAwareTask needs this
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
if not self.request.id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
task_id = self.request.id
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
mark_task_as_started_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
snapshot_generator = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for snapshot in snapshot_generator:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
writer.writerows(
|
||||
qa_pair.to_json()
|
||||
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
|
||||
snapshot
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to export query history with {task_id=}")
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
report_name = construct_query_history_report_name(task_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store().save_file(
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
file_type=FileType.CSV,
|
||||
file_metadata={
|
||||
"start": start.isoformat(),
|
||||
"end": end.isoformat(),
|
||||
"start_time": start_time.isoformat(),
|
||||
},
|
||||
file_id=report_name,
|
||||
)
|
||||
|
||||
delete_task_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to save query history export file; {report_name=}"
|
||||
)
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
@@ -125,5 +6,6 @@ celery_app.autodiscover_tasks(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
backend/ee/onyx/background/celery/tasks/query_history/tasks.py (new file, 119 lines)
@@ -0,0 +1,119 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def export_query_history_task(
|
||||
self: Task,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start_time: datetime,
|
||||
# Need to include the tenant_id since the TenantAwareTask needs this
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
if not self.request.id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
task_id = self.request.id
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
mark_task_as_started_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
snapshot_generator = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for snapshot in snapshot_generator:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
writer.writerows(
|
||||
qa_pair.to_json()
|
||||
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
|
||||
snapshot
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to export query history with {task_id=}")
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
report_name = construct_query_history_report_name(task_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store().save_file(
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
file_type=FileType.CSV,
|
||||
file_metadata={
|
||||
"start": start.isoformat(),
|
||||
"end": end.isoformat(),
|
||||
"start_time": start_time.isoformat(),
|
||||
},
|
||||
file_id=report_name,
|
||||
)
|
||||
|
||||
delete_task_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to save query history export file; {report_name=}"
|
||||
)
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
@@ -124,9 +124,9 @@ def get_space_permission(
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely"
|
||||
"to be correct and is more likely caused by an access token with"
|
||||
"insufficient permissions. Make sure that the access token has Admin"
|
||||
f"No permissions found for space '{space_key}'. This is very unlikely "
|
||||
"to be correct and is more likely caused by an access token with "
|
||||
"insufficient permissions. Make sure that the access token has Admin "
|
||||
f"permissions for space '{space_key}'"
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return gmail_connector.retrieve_all_slim_documents(
|
||||
return gmail_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
|
||||
else 0.0
|
||||
)
|
||||
|
||||
return google_drive_connector.retrieve_all_slim_documents(
|
||||
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
|
||||
start=start_time,
|
||||
end=current_time.timestamp(),
|
||||
callback=callback,
|
||||
|
||||
@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
|
||||
for raw_perm in permissions:
|
||||
if not hasattr(raw_perm, "raw"):
|
||||
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
continue
|
||||
|
||||
permission = Permission(**raw_perm.raw)
|
||||
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
# In order to associate this permission to some Atlassian entity, we need the "Holder".
|
||||
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
|
||||
if not permission.holder:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find a permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
type = permission.holder.get("type")
|
||||
if not type:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected to find the type of permission holder, but none was found: {permission=}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -105,7 +105,9 @@ def _get_slack_document_access(
|
||||
channel_permissions: dict[str, ExternalAccess],
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
|
||||
callback=callback
|
||||
)
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
|
||||
@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -17,7 +17,7 @@ def generic_doc_sync(
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
doc_source: DocumentSource,
|
||||
slim_connector: SlimConnector,
|
||||
slim_connector: SlimConnectorWithPermSync,
|
||||
label: str,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
@@ -40,7 +40,7 @@ def generic_doc_sync(
|
||||
newly_fetched_doc_ids: set[str] = set()
|
||||
|
||||
logger.info(f"Fetching all slim documents from {doc_source}")
|
||||
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
|
||||
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
|
||||
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
|
||||
|
||||
if callback:
|
||||
|
||||
15
backend/ee/onyx/feature_flags/factory.py
Normal file
15
backend/ee/onyx/feature_flags/factory.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
|
||||
|
||||
def get_posthog_feature_flag_provider() -> FeatureFlagProvider:
|
||||
"""
|
||||
Get the PostHog feature flag provider instance.
|
||||
|
||||
This is the EE implementation that gets loaded by the versioned
|
||||
implementation loader.
|
||||
|
||||
Returns:
|
||||
PostHogFeatureFlagProvider: The PostHog-based feature flag provider
|
||||
"""
|
||||
return PostHogFeatureFlagProvider()
|
||||
54
backend/ee/onyx/feature_flags/posthog_provider.py
Normal file
54
backend/ee/onyx/feature_flags/posthog_provider.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PostHogFeatureFlagProvider(FeatureFlagProvider):
|
||||
"""
|
||||
PostHog-based feature flag provider.
|
||||
|
||||
Uses PostHog's feature flag API to determine if features are enabled
|
||||
for specific users. Only active in multi-tenant mode.
|
||||
"""
|
||||
|
||||
def feature_enabled(
|
||||
self,
|
||||
flag_key: str,
|
||||
user_id: UUID,
|
||||
user_properties: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a feature flag is enabled for a user via PostHog.
|
||||
|
||||
Args:
|
||||
flag_key: The identifier for the feature flag to check
|
||||
user_id: The unique identifier for the user
|
||||
user_properties: Optional dictionary of user properties/attributes
|
||||
that may influence flag evaluation
|
||||
|
||||
Returns:
|
||||
True if the feature is enabled for the user, False otherwise.
|
||||
"""
|
||||
try:
|
||||
posthog.set(
|
||||
distinct_id=user_id,
|
||||
properties=user_properties,
|
||||
)
|
||||
is_enabled = posthog.feature_enabled(
|
||||
flag_key,
|
||||
str(user_id),
|
||||
person_properties=user_properties,
|
||||
)
|
||||
|
||||
return bool(is_enabled) if is_enabled is not None else False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking feature flag {flag_key} for user {user_id}: {e}"
|
||||
)
|
||||
return False
|
||||
@@ -1,45 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
from cohere import Client
|
||||
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
|
||||
Embedding = List[float]
|
||||
|
||||
|
||||
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
base_path = os.path.join(os.getcwd(), "onyx", "seeding")
|
||||
|
||||
if cohere_enabled and COHERE_DEFAULT_API_KEY:
|
||||
initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json")
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
|
||||
embed_model = "embed-english-v3.0"
|
||||
|
||||
for doc in processed_docs:
|
||||
title_embed_response = cohere_client.embed(
|
||||
texts=[doc["title"]],
|
||||
model=embed_model,
|
||||
input_type="search_document",
|
||||
)
|
||||
content_embed_response = cohere_client.embed(
|
||||
texts=[doc["content"]],
|
||||
model=embed_model,
|
||||
input_type="search_document",
|
||||
)
|
||||
|
||||
doc["title_embedding"] = cast(
|
||||
List[Embedding], title_embed_response.embeddings
|
||||
)[0]
|
||||
doc["content_embedding"] = cast(
|
||||
List[Embedding], content_embed_response.embeddings
|
||||
)[0]
|
||||
else:
|
||||
initial_docs_path = os.path.join(base_path, "initial_docs.json")
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
return processed_docs
|
||||
@@ -37,6 +37,19 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/auth/saml")
|
||||
|
||||
# Azure AD / Entra ID often returns the email attribute under different keys.
|
||||
# Keep a list of common variations so we can fall back gracefully if the IdP
|
||||
# does not send the plain "email" attribute name.
|
||||
EMAIL_ATTRIBUTE_KEYS = {
|
||||
"email",
|
||||
"emailaddress",
|
||||
"mail",
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress",
|
||||
"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/mail",
|
||||
"http://schemas.microsoft.com/identity/claims/emailaddress",
|
||||
}
|
||||
EMAIL_ATTRIBUTE_KEYS_LOWER = {key.lower() for key in EMAIL_ATTRIBUTE_KEYS}
|
||||
|
||||
|
||||
async def upsert_saml_user(email: str) -> User:
|
||||
"""
|
||||
@@ -204,16 +217,37 @@ async def _process_saml_callback(
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
user_email = auth.get_attribute("email")
|
||||
if not user_email:
|
||||
detail = "SAML is not set up correctly, email attribute must be provided."
|
||||
logger.error(detail)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
user_email: str | None = None
|
||||
|
||||
user_email = user_email[0]
|
||||
# The OneLogin toolkit normalizes attribute keys, but still performs a
|
||||
# case-sensitive lookup. Try the common keys first and then fall back to a
|
||||
# case-insensitive scan of all returned attributes.
|
||||
for attribute_key in EMAIL_ATTRIBUTE_KEYS:
|
||||
attribute_values = auth.get_attribute(attribute_key)
|
||||
if attribute_values:
|
||||
user_email = attribute_values[0]
|
||||
break
|
||||
|
||||
if not user_email:
|
||||
# Fallback: perform a case-insensitive lookup across all attributes in
|
||||
# case the IdP sent the email claim with a different capitalization.
|
||||
attributes = auth.get_attributes()
|
||||
for key, values in attributes.items():
|
||||
if key.lower() in EMAIL_ATTRIBUTE_KEYS_LOWER:
|
||||
if values:
|
||||
user_email = values[0]
|
||||
break
|
||||
if not user_email:
|
||||
detail = "SAML is not set up correctly, email attribute must be provided."
|
||||
logger.error(detail)
|
||||
logger.debug(
|
||||
"Received SAML attributes without email: %s",
|
||||
list(attributes.keys()),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
user = await upsert_saml_user(email=user_email)
|
||||
|
||||
|
||||
@@ -37,9 +37,9 @@ from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
@@ -278,7 +278,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for name in ANTHROPIC_MODEL_NAMES
|
||||
for name in get_anthropic_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
22
backend/ee/onyx/utils/posthog_client.py
Normal file
22
backend/ee/onyx/utils/posthog_client.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
@@ -1,27 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
@@ -37,23 +35,30 @@ from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: PreTrainedTokenizer | None = None
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
_INTENT_TOKENIZER: PreTrainedTokenizer | None = None
|
||||
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
|
||||
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> PreTrainedTokenizer:
|
||||
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
@@ -95,7 +100,9 @@ def get_local_connector_classifier(
|
||||
return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
def get_intent_model_tokenizer() -> PreTrainedTokenizer:
|
||||
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
@@ -141,7 +148,9 @@ def get_local_intent_model(
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> SetFitModel:
|
||||
) -> "SetFitModel":
|
||||
from setfit import SetFitModel
|
||||
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
@@ -179,7 +188,7 @@ def get_local_information_content_model(
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
connector_token_end_id: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -267,7 +276,7 @@ def warm_up_information_content_model() -> None:
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
|
||||
@@ -401,7 +410,7 @@ def run_content_classification_inference(
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: PreTrainedTokenizer, is_keyword: list[bool]
|
||||
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
|
||||
@@ -2,13 +2,11 @@ import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from litellm.exceptions import RateLimitError
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -20,6 +18,9 @@ from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
@@ -88,8 +89,10 @@ def get_embedding_model(
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> CrossEncoder:
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
model = CrossEncoder(model_name)
|
||||
@@ -207,6 +210,8 @@ async def route_bi_encoder_embed(
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
from litellm.exceptions import RateLimitError
|
||||
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if embed_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
|
||||
@@ -30,6 +30,7 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SKIP_WARM_UP
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
@@ -91,16 +92,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice(
|
||||
"The intent model should run on the model server. The information content model should not run here."
|
||||
)
|
||||
warm_up_intent_model()
|
||||
if not SKIP_WARM_UP:
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice("Warming up intent model for inference model server")
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"Warming up content information model for indexing model server"
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"The content information model should run on the indexing model server. The intent model should not run here."
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
from transformers import DistilBertModel # type: ignore
|
||||
from transformers import DistilBertTokenizer # type: ignore
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
config = DistilBertConfig()
|
||||
self.distilbert = DistilBertModel(config)
|
||||
@@ -74,7 +78,9 @@ class HybridClassifier(nn.Module):
|
||||
|
||||
|
||||
class ConnectorClassifier(nn.Module):
|
||||
def __init__(self, config: DistilBertConfig) -> None:
|
||||
def __init__(self, config: "DistilBertConfig") -> None:
|
||||
from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
@@ -115,6 +121,8 @@ class ConnectorClassifier(nn.Module):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
from transformers import DistilBertConfig
|
||||
|
||||
config = cast(
|
||||
DistilBertConfig,
|
||||
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from braintrust import traced
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
@@ -22,6 +23,9 @@ from onyx.agents.agent_search.dr.models import DecisionResponse
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import (
|
||||
BasicSearchProcessedStreamResults,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationSetup
|
||||
@@ -70,6 +74,7 @@ from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -116,7 +121,9 @@ def _get_available_tools(
|
||||
else:
|
||||
include_kg = False
|
||||
|
||||
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
|
||||
tool_dict: dict[int, Tool] = {
|
||||
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
|
||||
}
|
||||
|
||||
for tool in graph_config.tooling.tools:
|
||||
|
||||
@@ -484,6 +491,7 @@ def clarifier(
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
@@ -666,28 +674,30 @@ def clarifier(
|
||||
system_prompt_to_use = assistant_system_prompt
|
||||
user_prompt_to_use = decision_prompt + assistant_task_prompt
|
||||
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
|
||||
full_response = process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
@traced(name="clarifier stream and process", type="llm")
|
||||
def stream_and_process() -> BasicSearchProcessedStreamResults:
|
||||
stream = graph_config.tooling.primary_llm.stream(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
return process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
final_search_results=context_llm_docs,
|
||||
displayed_search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
|
||||
full_response = stream_and_process()
|
||||
if len(full_response.ai_message_chunk.tool_calls) == 0:
|
||||
|
||||
if isinstance(full_response.full_answer, str):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@@ -28,6 +29,7 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -62,6 +64,29 @@ def image_generation(
|
||||
image_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
|
||||
|
||||
image_prompt = branch_query
|
||||
requested_shape: ImageShape | None = None
|
||||
|
||||
try:
|
||||
parsed_query = json.loads(branch_query)
|
||||
except json.JSONDecodeError:
|
||||
parsed_query = None
|
||||
|
||||
if isinstance(parsed_query, dict):
|
||||
prompt_from_llm = parsed_query.get("prompt")
|
||||
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
|
||||
image_prompt = prompt_from_llm.strip()
|
||||
|
||||
raw_shape = parsed_query.get("shape")
|
||||
if isinstance(raw_shape, str):
|
||||
try:
|
||||
requested_shape = ImageShape(raw_shape)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Received unsupported image shape '%s' from LLM. Falling back to square.",
|
||||
raw_shape,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
@@ -69,7 +94,15 @@ def image_generation(
|
||||
# Generate images using the image generation tool
|
||||
image_generation_responses: list[ImageGenerationResponse] = []
|
||||
|
||||
for tool_response in image_tool.run(prompt=branch_query):
|
||||
if requested_shape is not None:
|
||||
tool_iterator = image_tool.run(
|
||||
prompt=image_prompt,
|
||||
shape=requested_shape.value,
|
||||
)
|
||||
else:
|
||||
tool_iterator = image_tool.run(prompt=image_prompt)
|
||||
|
||||
for tool_response in tool_iterator:
|
||||
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
|
||||
# Stream heartbeat to frontend
|
||||
write_custom_event(
|
||||
@@ -95,6 +128,7 @@ def image_generation(
|
||||
file_id=file_id,
|
||||
url=build_frontend_file_url(file_id),
|
||||
revised_prompt=img.revised_prompt,
|
||||
shape=(requested_shape or ImageShape.SQUARE).value,
|
||||
)
|
||||
for file_id, img in zip(file_ids, image_generation_responses)
|
||||
]
|
||||
@@ -107,15 +141,29 @@ def image_generation(
|
||||
if final_generated_images:
|
||||
image_descriptions = []
|
||||
for i, img in enumerate(final_generated_images, 1):
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
if img.shape and img.shape != ImageShape.SQUARE.value:
|
||||
image_descriptions.append(
|
||||
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
|
||||
)
|
||||
else:
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
|
||||
answer_string = (
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {branch_query}\n\n"
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
|
||||
+ "\n".join(image_descriptions)
|
||||
)
|
||||
reasoning = f"Used image generation tool to create {len(final_generated_images)} image(s) based on the user's request."
|
||||
if requested_shape:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
|
||||
)
|
||||
else:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) based on the user's request."
|
||||
)
|
||||
else:
|
||||
answer_string = f"Failed to generate images for request: {branch_query}"
|
||||
answer_string = f"Failed to generate images for request: {image_prompt}"
|
||||
reasoning = "Image generation tool did not return any results."
|
||||
|
||||
return BranchUpdate(
|
||||
|
||||
@@ -5,6 +5,7 @@ class GeneratedImage(BaseModel):
|
||||
file_id: str
|
||||
url: str
|
||||
revised_prompt: str
|
||||
shape: str | None = None
|
||||
|
||||
|
||||
# Needed for PydanticType
|
||||
|
||||
@@ -2,30 +2,28 @@ from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
# TODO Dependency inject for testing
|
||||
class ExaClient(InternetSearchProvider):
|
||||
class ExaClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
|
||||
self.exa = Exa(api_key=api_key)
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
response = self.exa.search_and_contents(
|
||||
query,
|
||||
type="fast",
|
||||
livecrawl="never",
|
||||
type="auto",
|
||||
highlights=HighlightsContentsOptions(
|
||||
num_sentences=2,
|
||||
highlights_per_url=1,
|
||||
@@ -34,7 +32,7 @@ class ExaClient(InternetSearchProvider):
|
||||
)
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
WebSearchResult(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
snippet=result.highlights[0] if result.highlights else "",
|
||||
@@ -49,7 +47,7 @@ class ExaClient(InternetSearchProvider):
|
||||
]
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
response = self.exa.get_contents(
|
||||
urls=urls,
|
||||
text=True,
|
||||
@@ -57,7 +55,7 @@ class ExaClient(InternetSearchProvider):
|
||||
)
|
||||
|
||||
return [
|
||||
InternetContent(
|
||||
WebContent(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
full_content=result.text or "",
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
|
||||
self.headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
SERPER_SEARCH_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
organic_results = results["organic"]
|
||||
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result["title"],
|
||||
link=result["link"],
|
||||
snippet=result["snippet"],
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Serper can responds with 500s regularly. We want to retry,
|
||||
# but in the event of failure, return an unsuccesful scrape.
|
||||
def safe_get_webpage_content(url: str) -> WebContent:
|
||||
try:
|
||||
return self._get_webpage_content(url)
|
||||
except Exception:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
|
||||
return list(e.map(safe_get_webpage_content, urls))
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def _get_webpage_content(self, url: str) -> WebContent:
|
||||
payload = {
|
||||
"url": url,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
SERPER_CONTENTS_URL,
|
||||
headers=self.headers,
|
||||
data=json.dumps(payload),
|
||||
)
|
||||
|
||||
# 400 returned when serper cannot scrape
|
||||
if response.status_code == 400:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Response only guarantees text
|
||||
text = response_json["text"]
|
||||
|
||||
# metadata & jsonld is not guaranteed to be present
|
||||
metadata = response_json.get("metadata", {})
|
||||
jsonld = response_json.get("jsonld", {})
|
||||
|
||||
title = extract_title_from_metadata(metadata)
|
||||
|
||||
# Serper does not provide a reliable mechanism to extract the url
|
||||
response_url = url
|
||||
published_date_str = extract_published_date_from_jsonld(jsonld)
|
||||
published_date = None
|
||||
|
||||
if published_date_str:
|
||||
try:
|
||||
published_date = time_str_to_utc(published_date_str)
|
||||
except Exception:
|
||||
published_date = None
|
||||
|
||||
return WebContent(
|
||||
title=title or "",
|
||||
link=response_url,
|
||||
full_content=text or "",
|
||||
published_date=published_date,
|
||||
)
|
||||
|
||||
|
||||
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
|
||||
keys = ["title", "og:title"]
|
||||
return extract_value_from_dict(metadata, keys)
|
||||
|
||||
|
||||
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
|
||||
keys = ["dateModified"]
|
||||
return extract_value_from_dict(jsonld, keys)
|
||||
|
||||
|
||||
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
|
||||
for key in keys:
|
||||
if key in data:
|
||||
return data[key]
|
||||
return None
|
||||
@@ -7,7 +7,7 @@ from langsmith import traceable
|
||||
|
||||
from onyx.agents.agent_search.dr.models import WebSearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
@@ -75,15 +75,15 @@ def web_search(
|
||||
raise ValueError("No internet search provider found")
|
||||
|
||||
@traceable(name="Search Provider API Call")
|
||||
def _search(search_query: str) -> list[InternetSearchResult]:
|
||||
search_results: list[InternetSearchResult] = []
|
||||
def _search(search_query: str) -> list[WebSearchResult]:
|
||||
search_results: list[WebSearchResult] = []
|
||||
try:
|
||||
search_results = provider.search(search_query)
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing search: {e}")
|
||||
return search_results
|
||||
|
||||
search_results: list[InternetSearchResult] = _search(search_query)
|
||||
search_results: list[WebSearchResult] = _search(search_query)
|
||||
search_results_text = "\n\n".join(
|
||||
[
|
||||
f"{i}. {result.title}\n URL: {result.link}\n"
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
@@ -23,7 +23,7 @@ def dedup_urls(
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InternetSearchInput:
|
||||
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
|
||||
unique_results_by_link: dict[str, InternetSearchResult] = {}
|
||||
unique_results_by_link: dict[str, WebSearchResult] = {}
|
||||
for query, result in state.results_to_open:
|
||||
branch_questions_to_urls[query].append(result.link)
|
||||
if result.link not in unique_results_by_link:
|
||||
|
||||
@@ -13,7 +13,7 @@ class ProviderType(Enum):
|
||||
EXA = "exa"
|
||||
|
||||
|
||||
class InternetSearchResult(BaseModel):
|
||||
class WebSearchResult(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
author: str | None = None
|
||||
@@ -21,18 +21,19 @@ class InternetSearchResult(BaseModel):
|
||||
snippet: str | None = None
|
||||
|
||||
|
||||
class InternetContent(BaseModel):
|
||||
class WebContent(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
full_content: str
|
||||
published_date: datetime | None = None
|
||||
scrape_successful: bool = True
|
||||
|
||||
|
||||
class InternetSearchProvider(ABC):
|
||||
class WebSearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
pass
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def get_default_provider() -> InternetSearchProvider | None:
|
||||
def get_default_provider() -> WebSearchProvider | None:
|
||||
if EXA_API_KEY:
|
||||
return ExaClient()
|
||||
if SERPER_API_KEY:
|
||||
return SerperClient()
|
||||
return None
|
||||
|
||||
@@ -4,13 +4,13 @@ from typing import Annotated
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class InternetSearchInput(SubAgentInput):
|
||||
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
parallelization_nr: int = 0
|
||||
branch_question: Annotated[str, lambda x, y: y] = ""
|
||||
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
|
||||
@@ -18,7 +18,7 @@ class InternetSearchInput(SubAgentInput):
|
||||
|
||||
|
||||
class InternetSearchUpdate(LoggerUpdate):
|
||||
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
|
||||
|
||||
class FetchInput(SubAgentInput):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -17,7 +17,7 @@ def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_content(
|
||||
result: InternetContent,
|
||||
result: WebContent,
|
||||
) -> InferenceSection:
|
||||
truncated_content = truncate_search_result_content(result.full_content)
|
||||
return InferenceSection(
|
||||
@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
|
||||
boost=1,
|
||||
recency_bias=1.0,
|
||||
score=1.0,
|
||||
hidden=False,
|
||||
hidden=(not result.scrape_successful),
|
||||
metadata={},
|
||||
match_highlights=[],
|
||||
doc_summary=truncated_content,
|
||||
@@ -48,7 +48,7 @@ def dummy_inference_section_from_internet_content(
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_search_result(
|
||||
result: InternetSearchResult,
|
||||
result: WebSearchResult,
|
||||
) -> InferenceSection:
|
||||
return InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
@@ -88,11 +87,5 @@ class GraphConfig(BaseModel):
|
||||
# Only needed for agentic search
|
||||
persistence: GraphPersistence
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_tool(self) -> "GraphConfig":
|
||||
if self.behavior.use_agentic_search and self.tooling.search_tool is None:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.tools.message import build_tool_message
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.tool_runner import ToolRunner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
|
||||
def call_tool(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ToolCallUpdate:
|
||||
"""Calls the tool specified in the state and updates the state with the result"""
|
||||
|
||||
cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
tool_choice = state.tool_choice
|
||||
if tool_choice is None:
|
||||
raise ValueError("Cannot invoke tool call node without a tool choice")
|
||||
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(
|
||||
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
|
||||
)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
try:
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
|
||||
tool_final_result = tool_runner.tool_final_result()
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
f"Error during tool call for {tool.display_name}: {e}"
|
||||
) from e
|
||||
|
||||
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
tool_call_output = ToolCallOutput(
|
||||
tool_call_summary=tool_call_summary,
|
||||
tool_call_kickoff=tool_kickoff,
|
||||
tool_call_responses=tool_responses,
|
||||
tool_call_final_result=tool_final_result,
|
||||
)
|
||||
return ToolCallUpdate(tool_call_output=tool_call_output)
|
||||
@@ -1,354 +0,0 @@
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoice
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
|
||||
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
from onyx.tools.models import QueryExpansions
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
|
||||
# TODO: Add trimming logic
|
||||
history_segments = []
|
||||
for msg in prompt_builder.message_history:
|
||||
if isinstance(msg, HumanMessage):
|
||||
role = "User"
|
||||
elif isinstance(msg, AIMessage):
|
||||
role = "Assistant"
|
||||
else:
|
||||
continue
|
||||
history_segments.append(f"{role}:\n {msg.content}\n\n")
|
||||
return "\n".join(history_segments)
|
||||
|
||||
|
||||
def _expand_query(
|
||||
query: str,
|
||||
expansion_type: QueryExpansionType,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
) -> str:
|
||||
|
||||
history_str = _create_history_str(prompt_builder)
|
||||
|
||||
if history_str:
|
||||
if expansion_type == QueryExpansionType.KEYWORD:
|
||||
base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
|
||||
else:
|
||||
base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
|
||||
expansion_prompt = base_prompt.format(question=query, history=history_str)
|
||||
else:
|
||||
if expansion_type == QueryExpansionType.KEYWORD:
|
||||
base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
else:
|
||||
base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
|
||||
expansion_prompt = base_prompt.format(question=query)
|
||||
|
||||
msg = HumanMessage(content=expansion_prompt)
|
||||
primary_llm, _ = get_default_llms()
|
||||
response = primary_llm.invoke([msg])
|
||||
rephrased_query: str = cast(str, response.content)
|
||||
|
||||
return rephrased_query
|
||||
|
||||
|
||||
def _expand_query_non_tool_calling_llm(
|
||||
expanded_keyword_thread: TimeoutThread[str],
|
||||
expanded_semantic_thread: TimeoutThread[str],
|
||||
) -> QueryExpansions | None:
|
||||
keyword_expansion: str | None = wait_on_background(expanded_keyword_thread)
|
||||
semantic_expansion: str | None = wait_on_background(expanded_semantic_thread)
|
||||
|
||||
if keyword_expansion is None or semantic_expansion is None:
|
||||
return None
|
||||
|
||||
return QueryExpansions(
|
||||
keywords_expansions=[keyword_expansion],
|
||||
semantic_expansions=[semantic_expansion],
|
||||
)
|
||||
|
||||
|
||||
# TODO: break this out into an implementation function
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
@log_function_time(print_only=True)
|
||||
def choose_tool(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ToolChoiceUpdate:
|
||||
"""
|
||||
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
|
||||
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
|
||||
"""
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
embedding_thread: TimeoutThread[Embedding] | None = None
|
||||
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
|
||||
expanded_keyword_thread: TimeoutThread[str] | None = None
|
||||
expanded_semantic_thread: TimeoutThread[str] | None = None
|
||||
# If we have override_kwargs, add them to the tool_args
|
||||
override_kwargs: SearchToolOverrideKwargs = (
|
||||
force_use_tool.override_kwargs or SearchToolOverrideKwargs()
|
||||
)
|
||||
override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
llm = agent_config.tooling.primary_llm
|
||||
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
|
||||
|
||||
if (
|
||||
not agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and (
|
||||
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
|
||||
)
|
||||
):
|
||||
# Run in a background thread to avoid blocking the main thread
|
||||
embedding_thread = run_in_background(
|
||||
get_query_embedding,
|
||||
agent_config.inputs.prompt_builder.raw_user_query,
|
||||
agent_config.persistence.db_session,
|
||||
)
|
||||
keyword_thread = run_in_background(
|
||||
query_analysis,
|
||||
agent_config.inputs.prompt_builder.raw_user_query,
|
||||
)
|
||||
|
||||
if USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH:
|
||||
|
||||
expanded_keyword_thread = run_in_background(
|
||||
_expand_query,
|
||||
agent_config.inputs.prompt_builder.raw_user_query,
|
||||
QueryExpansionType.KEYWORD,
|
||||
prompt_builder,
|
||||
)
|
||||
expanded_semantic_thread = run_in_background(
|
||||
_expand_query,
|
||||
agent_config.inputs.prompt_builder.raw_user_query,
|
||||
QueryExpansionType.SEMANTIC,
|
||||
prompt_builder,
|
||||
)
|
||||
|
||||
structured_response_format = agent_config.inputs.structured_response_format
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
tool_name, tool_args = (
|
||||
force_use_tool.tool_name,
|
||||
force_use_tool.args,
|
||||
)
|
||||
tool = get_tool_by_name(tools, tool_name)
|
||||
|
||||
# special pre-logic for non-tool calling LLM case
|
||||
elif not using_tool_calling_llm and tools:
|
||||
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
|
||||
force_use_tool=force_use_tool,
|
||||
tools=tools,
|
||||
prompt_builder=prompt_builder,
|
||||
llm=llm,
|
||||
)
|
||||
if chosen_tool_and_args:
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
if embedding_thread and tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
# dual keyword expansion needs to be added here for non-tool calling LLM case
|
||||
if (
|
||||
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
|
||||
and expanded_keyword_thread
|
||||
and expanded_semantic_thread
|
||||
and tool.name == SearchTool._NAME
|
||||
):
|
||||
override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
|
||||
expanded_keyword_thread=expanded_keyword_thread,
|
||||
expanded_semantic_thread=expanded_semantic_thread,
|
||||
)
|
||||
if (
|
||||
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
|
||||
and tool.name == SearchTool._NAME
|
||||
and override_kwargs.expanded_queries
|
||||
):
|
||||
if (
|
||||
override_kwargs.expanded_queries.keywords_expansions is None
|
||||
or override_kwargs.expanded_queries.semantic_expansions is None
|
||||
):
|
||||
raise ValueError("No expanded keyword or semantic threads found.")
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
# if we're skipping gen ai answer generation, we should only
|
||||
# continue if we're forcing a tool call (which will be emitted by
|
||||
# the tool calling llm in the stream() below)
|
||||
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
built_prompt = (
|
||||
prompt_builder.build()
|
||||
if isinstance(prompt_builder, AnswerPromptBuilder)
|
||||
else prompt_builder.built_prompt
|
||||
)
|
||||
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
|
||||
# DEBUG: good breakpoint
|
||||
stream = llm.stream(
|
||||
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
prompt=built_prompt,
|
||||
tools=(
|
||||
[tool.tool_definition() for tool in tools] or None
|
||||
if using_tool_calling_llm
|
||||
else None
|
||||
),
|
||||
tool_choice=(
|
||||
"required"
|
||||
if tools and force_use_tool.force_use and using_tool_calling_llm
|
||||
else None
|
||||
),
|
||||
structured_response_format=structured_response_format,
|
||||
)
|
||||
|
||||
tool_message = process_llm_stream(
|
||||
stream,
|
||||
should_stream_answer
|
||||
and not agent_config.behavior.skip_gen_ai_answer_generation,
|
||||
writer,
|
||||
ind=0,
|
||||
).ai_message_chunk
|
||||
|
||||
if tool_message is None:
|
||||
raise ValueError("No tool message emitted by LLM")
|
||||
|
||||
# If no tool calls are emitted by the LLM, we should not choose a tool
|
||||
if len(tool_message.tool_calls) == 0:
|
||||
logger.debug("No tool calls emitted by LLM")
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
# TODO: here we could handle parallel tool calls. Right now
|
||||
# we just pick the first one that matches.
|
||||
selected_tool: Tool | None = None
|
||||
selected_tool_call_request: ToolCall | None = None
|
||||
for tool_call_request in tool_message.tool_calls:
|
||||
known_tools_by_name = [
|
||||
tool for tool in tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if known_tools_by_name:
|
||||
selected_tool = known_tools_by_name[0]
|
||||
selected_tool_call_request = tool_call_request
|
||||
break
|
||||
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"tools: {tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
|
||||
if not selected_tool or not selected_tool_call_request:
|
||||
raise ValueError(
|
||||
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
|
||||
)
|
||||
|
||||
logger.debug(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
if embedding_thread and selected_tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and selected_tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
|
||||
if (
|
||||
selected_tool.name == SearchTool._NAME
|
||||
and expanded_keyword_thread
|
||||
and expanded_semantic_thread
|
||||
):
|
||||
|
||||
override_kwargs.expanded_queries = _expand_query_non_tool_calling_llm(
|
||||
expanded_keyword_thread=expanded_keyword_thread,
|
||||
expanded_semantic_thread=expanded_semantic_thread,
|
||||
)
|
||||
if (
|
||||
USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH
|
||||
and selected_tool.name == SearchTool._NAME
|
||||
and override_kwargs.expanded_queries
|
||||
):
|
||||
# TODO: this is a hack to handle the case where the expanded queries are not found.
|
||||
# We should refactor this to be more robust.
|
||||
if (
|
||||
override_kwargs.expanded_queries.keywords_expansions is None
|
||||
or override_kwargs.expanded_queries.semantic_expansions is None
|
||||
):
|
||||
raise ValueError("No expanded keyword or semantic threads found.")
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
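Before the forced-tool and non-tool-calling branches above, the node starts the query embedding, keyword analysis, and (optionally) query expansion on background threads, and only joins those threads once the search tool is actually selected. Below is a minimal sketch of that precompute-then-join pattern, using the standard library rather than the run_in_background/wait_on_background helpers from the diff; fake_embed and fake_analyze are hypothetical stand-ins for get_query_embedding and query_analysis:

from concurrent.futures import ThreadPoolExecutor


def fake_embed(query: str) -> list[float]:
    # hypothetical stand-in for get_query_embedding
    return [0.0] * 8


def fake_analyze(query: str) -> tuple[bool, list[str]]:
    # hypothetical stand-in for query_analysis
    return (False, query.split())


def choose_tool_with_precompute(query: str, search_tool_chosen: bool) -> dict | None:
    with ThreadPoolExecutor(max_workers=2) as pool:
        # start the expensive work before the (slow) LLM tool-choice call
        embedding_future = pool.submit(fake_embed, query)
        keyword_future = pool.submit(fake_analyze, query)

        # ... the LLM tool-choice call would happen here ...

        if not search_tool_chosen:
            # the precomputed results are simply discarded if search was not picked
            return None

        is_keyword, keywords = keyword_future.result()
        return {
            "embedding": embedding_future.result(),
            "is_keyword": is_keyword,
            "keywords": keywords,
        }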
@@ -1,17 +0,0 @@
from typing import Any
from typing import cast

from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput


def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
    agent_config = cast(GraphConfig, config["metadata"]["config"])
    return ToolChoiceInput(
        # NOTE: this node is used at the top level of the agent, so we always stream
        should_stream_answer=True,
        prompt_snapshot=None,  # uses default prompt builder
        tools=[tool.name for tool in (agent_config.tooling.tools or [])],
    )
@@ -5,11 +5,10 @@ from typing import Literal
|
||||
from typing import Type
|
||||
from typing import TypeVar
|
||||
|
||||
from braintrust import traced
|
||||
from langchain.schema.language_model import LanguageModelInput
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import supports_response_schema
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
@@ -29,6 +28,7 @@ SchemaType = TypeVar("SchemaType", bound=BaseModel)
|
||||
JSON_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)
|
||||
|
||||
|
||||
@traced(name="stream llm", type="llm")
|
||||
def stream_llm_answer(
|
||||
llm: LLM,
|
||||
prompt: LanguageModelInput,
|
||||
@@ -147,6 +147,7 @@ def invoke_llm_json(
|
||||
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
|
||||
and return an object of that schema.
|
||||
"""
|
||||
from litellm.utils import get_supported_openai_params, supports_response_schema
|
||||
|
||||
# check if the model supports response_format: json_schema
|
||||
supports_json = "response_format" in (
|
||||
|
||||
@@ -29,7 +29,7 @@ from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
|
||||
from onyx.configs.constants import ONYX_SLACK_URL
|
||||
from onyx.configs.constants import ONYX_DISCORD_URL
|
||||
from onyx.db.models import User
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -145,7 +145,7 @@ HTML_EMAIL_TEMPLATE = """\
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} {application_name}. All rights reserved.
|
||||
{slack_fragment}
|
||||
{community_link_fragment}
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
@@ -161,9 +161,9 @@ def build_html_email(
|
||||
cta_text: str | None = None,
|
||||
cta_link: str | None = None,
|
||||
) -> str:
|
||||
slack_fragment = ""
|
||||
community_link_fragment = ""
|
||||
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
|
||||
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
|
||||
community_link_fragment = f'<br>Have questions? Join our Discord community <a href="{ONYX_DISCORD_URL}">here</a>.'
|
||||
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
@@ -175,7 +175,7 @@ def build_html_email(
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
slack_fragment=slack_fragment,
|
||||
community_link_fragment=community_link_fragment,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
|
||||
@@ -1040,7 +1040,10 @@ async def optional_user(
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
try:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
except ValueError:
|
||||
hashed_api_key = None
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
|
||||
111
backend/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.background")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received for consolidated background worker.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Ensure the user files indexing worker registers the doc_id migration task
|
||||
# TODO(subash): remove this once the doc_id migration is complete
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -324,5 +324,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -19,7 +19,9 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -30,7 +32,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: Iterator[list[Document]],
|
||||
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
@@ -41,20 +43,24 @@ def extract_ids_from_runnable_connector(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the SlimConnector hasn't been implemented for the given connector, just pull
|
||||
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_id_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
if isinstance(runnable_connector, SlimConnector):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs()
|
||||
)
|
||||
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.retrieve_all_slim_docs_perm_sync()
|
||||
)
|
||||
# If the connector isn't slim, fall back to running it normally to get ids
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
@@ -78,13 +84,14 @@ def extract_ids_from_runnable_connector(
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
# this function is called per batch for rate limiting
|
||||
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
|
||||
return doc_batch_ids
|
||||
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
doc_batch_processing_func = (
|
||||
rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(lambda x: x)
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
|
||||
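The pruning change above collapses the conditional reassignment into one expression: when MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE is set, rate_limit_builder wraps a plain identity function so that each per-batch call counts against the per-minute budget; otherwise the identity function is used directly. A rough, self-contained sketch of what such a decorator factory can look like (an illustration only, not the actual rate_limit_builder implementation in the repo):

import time
from collections import deque
from collections.abc import Callable
from typing import Any


def rate_limit_builder(max_calls: int, period: float) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Allow at most max_calls invocations of the wrapped function per period seconds."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        call_times: deque[float] = deque()

        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # forget calls that have left the sliding window
            while call_times and now - call_times[0] > period:
                call_times.popleft()
            if len(call_times) >= max_calls:
                # wait until the oldest call exits the window, then drop it
                time.sleep(max(0.0, period - (now - call_times[0])))
                call_times.popleft()
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper

    return decorator


# usage mirroring the diff: a rate-limited identity function applied once per doc batch
limited_identity = rate_limit_builder(max_calls=10, period=60)(lambda x: x)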
21
backend/onyx/background/celery/configs/background.py
Normal file
@@ -0,0 +1,21 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -16,6 +17,6 @@ task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Monitoring worker specific settings
|
||||
worker_concurrency = 1 # Single worker is sufficient for monitoring
|
||||
worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -33,17 +33,25 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-user-file-project-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "user-file-docid-migration",
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -85,9 +93,9 @@ beat_task_templates: list[dict] = [
|
||||
{
|
||||
"name": "check-for-index-attempt-cleanup",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(minutes=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -89,6 +89,7 @@ from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
@@ -1270,8 +1271,6 @@ def _docprocessing_task(
|
||||
tenant_id: str,
|
||||
batch_num: int,
|
||||
) -> None:
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
if tenant_id:
|
||||
|
||||
@@ -9,17 +9,19 @@ import sqlalchemy as sa
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
@@ -36,13 +38,12 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
@@ -134,7 +135,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
@@ -306,8 +308,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
UserFile.needs_project_sync.is_(True)
|
||||
and UserFile.status == UserFileStatus.COMPLETED
|
||||
sa.and_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
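The check_for_user_file_project_sync change above is a real bug fix rather than a style tweak: combining two SQLAlchemy column expressions with Python's `and` never builds a SQL AND; it forces a truth test on the first expression, so at best only one condition reaches where() and at worst query construction fails. sa.and_() (or passing the conditions as separate where() arguments) produces the intended conjunction. A small sketch with a hypothetical stand-in table:

import sqlalchemy as sa

metadata = sa.MetaData()
# hypothetical stand-in for the UserFile model referenced above
user_file = sa.Table(
    "user_file",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("needs_project_sync", sa.Boolean),
    sa.Column("status", sa.String),
)

# both conditions end up ANDed together in the WHERE clause
stmt = sa.select(user_file.c.id).where(
    sa.and_(
        user_file.c.needs_project_sync.is_(True),
        user_file.c.status == "COMPLETED",
    )
)

# equivalent: where() accepts multiple conditions and ANDs them
stmt_alt = sa.select(user_file.c.id).where(
    user_file.c.needs_project_sync.is_(True),
    user_file.c.status == "COMPLETED",
)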
@@ -348,7 +352,7 @@ def process_single_user_file_project_sync(
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
@@ -359,6 +363,15 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
@@ -416,6 +429,7 @@ def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
|
||||
return old_id
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
|
||||
def _visit_chunks(
|
||||
*,
|
||||
http_client: httpx.Client,
|
||||
@@ -423,10 +437,13 @@ def _visit_chunks(
|
||||
selection: str,
|
||||
continuation: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
task_logger.info(
|
||||
f"Visiting chunks for index={index_name} with selection={selection}"
|
||||
)
|
||||
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
params: dict[str, str] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": "1000",
|
||||
"wantedDocumentCount": "100", # Use smaller batch size to avoid timeouts
|
||||
}
|
||||
if continuation:
|
||||
params["continuation"] = continuation
|
||||
@@ -436,127 +453,155 @@ def _visit_chunks(
|
||||
return payload.get("documents", []), payload.get("continuation")
|
||||
|
||||
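_visit_chunks above wraps Vespa's visit-style document API: it fetches one bounded page of documents matching a selection (the page size was reduced to 100 here to avoid timeouts) and returns a continuation token when more pages remain. Callers drain it with a loop like the one used further down in this diff; here is a minimal sketch of that loop, assuming the _visit_chunks defined in this file and an already-initialized httpx client:

import httpx


def count_chunks_for_selection(
    http_client: httpx.Client, index_name: str, selection: str
) -> int:
    """Illustration: page through every matching chunk via the continuation token."""
    total = 0
    continuation: str | None = None
    while True:
        docs, continuation = _visit_chunks(
            http_client=http_client,
            index_name=index_name,
            selection=selection,
            continuation=continuation,
        )
        total += len(docs)
        # an empty page or a missing continuation token means we are done
        if not docs or not continuation:
            break
    return total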
|
||||
def _update_document_id_in_vespa(
|
||||
*,
|
||||
index_name: str,
|
||||
old_doc_id: str,
|
||||
new_doc_id: str,
|
||||
user_project_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
|
||||
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
|
||||
def update_legacy_plaintext_file_records() -> None:
|
||||
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
|
||||
keys. Copies each S3 object to its expected UUID key and updates DB.
|
||||
|
||||
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
|
||||
task_logger.debug(f"Vespa selection: {selection}")
|
||||
Examples:
|
||||
- Old key: bucket/schema/plaintext_<int>
|
||||
- New key: bucket/schema/plaintext_<uuid>
|
||||
"""
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
continuation: str | None = None
|
||||
while True:
|
||||
docs, continuation = _visit_chunks(
|
||||
http_client=http_client,
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
continuation=continuation,
|
||||
task_logger.info("update_legacy_plaintext_file_records - Starting")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
store = get_default_file_store()
|
||||
|
||||
if not isinstance(store, S3BackedFileStore):
|
||||
task_logger.info(
|
||||
"update_legacy_plaintext_file_records - Skipping non-S3 store"
|
||||
)
|
||||
if not docs:
|
||||
break
|
||||
for doc in docs:
|
||||
vespa_full_id = doc.get("id")
|
||||
if not vespa_full_id:
|
||||
return
|
||||
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
|
||||
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
vespa_doc_uuid = vespa_full_id.split("::")[-1]
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
|
||||
update_request: dict[str, Any] = {
|
||||
"fields": {"document_id": {"assign": clean_new_doc_id}}
|
||||
}
|
||||
if user_project_ids is not None:
|
||||
update_request["fields"][USER_PROJECT] = {
|
||||
"assign": user_project_ids
|
||||
}
|
||||
r = http_client.put(vespa_url, json=update_request)
|
||||
r.raise_for_status()
|
||||
if not continuation:
|
||||
break
|
||||
|
||||
if fr.bucket_name is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
|
||||
continue
|
||||
|
||||
if fr.object_key is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Object key is None")
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
ignore_result=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
bind=True,
|
||||
)
|
||||
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
|
||||
|
||||
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
|
||||
- Update `search_doc.document_id` to the same UUID string.
|
||||
"""
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
|
||||
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
updated_count = 0
|
||||
try:
|
||||
update_legacy_plaintext_file_records()
|
||||
# Track lock renewal
|
||||
last_lock_time = time.monotonic()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
active_settings.primary,
|
||||
active_settings.secondary,
|
||||
search_settings=active_settings.primary,
|
||||
secondary_search_settings=active_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
index_name = "danswer_index"
|
||||
|
||||
# Fetch mappings of legacy -> new ids
|
||||
rows = db_session.execute(
|
||||
sa.select(
|
||||
UserFile.document_id.label("document_id"),
|
||||
UserFile.id.label("id"),
|
||||
).where(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
).all()
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
|
||||
# dedupe by old document_id
|
||||
seen: set[str] = set()
|
||||
for row in rows:
|
||||
old_doc_id = str(row.document_id)
|
||||
new_uuid = str(row.id)
|
||||
if not old_doc_id or not new_uuid or old_doc_id in seen:
|
||||
continue
|
||||
seen.add(old_doc_id)
|
||||
# collect user project ids for a combined Vespa update
|
||||
user_project_ids: list[int] | None = None
|
||||
try:
|
||||
uf = db_session.get(UserFile, UUID(new_uuid))
|
||||
if uf is not None:
|
||||
user_project_ids = [project.id for project in uf.projects]
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
try:
|
||||
_update_document_id_in_vespa(
|
||||
index_name=index_name,
|
||||
old_doc_id=old_doc_id,
|
||||
new_doc_id=new_uuid,
|
||||
user_project_ids=user_project_ids,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
|
||||
)
|
||||
# Update search_doc records to refer to the UUID string
|
||||
# we are not using document_id_migrated = false because if the migration already completed,
|
||||
# it will not run again and we will not update the search_doc records because of the issue currently fixed
|
||||
# Select user files with a legacy doc id that have not been migrated
|
||||
user_files = (
|
||||
db_session.execute(
|
||||
sa.select(UserFile).where(UserFile.document_id.is_not(None))
|
||||
sa.select(UserFile).where(
|
||||
sa.and_(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
|
||||
)
|
||||
|
||||
# Query all SearchDocs that need updating
|
||||
search_docs = (
|
||||
db_session.execute(
|
||||
@@ -567,9 +612,9 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(f"Found {len(user_files)} user files to update")
|
||||
task_logger.info(f"Found {len(search_docs)} search docs to update")
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
|
||||
)
|
||||
|
||||
# Build a map of normalized doc IDs to SearchDocs
|
||||
search_doc_map: dict[str, list[SearchDoc]] = {}
|
||||
@@ -579,102 +624,139 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
# Process each UserFile and update matching SearchDocs
|
||||
updated_count = 0
|
||||
for uf in user_files:
|
||||
doc_id = uf.document_id
|
||||
if doc_id.startswith("USER_FILE_CONNECTOR__"):
|
||||
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
|
||||
if doc_id in search_doc_map:
|
||||
# Update the SearchDoc to use the UserFile's UUID
|
||||
for search_doc in search_doc_map[doc_id]:
|
||||
search_doc.document_id = str(uf.id)
|
||||
db_session.add(search_doc)
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - search_doc_map total items: "
|
||||
f"{sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
# Periodically renew the Redis lock to prevent expiry mid-run
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
|
||||
):
|
||||
renewed = False
|
||||
try:
|
||||
# extend lock ttl to full timeout window
|
||||
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
|
||||
renewed = True
|
||||
except Exception:
|
||||
# if extend fails, best-effort reacquire as a fallback
|
||||
try:
|
||||
lock.reacquire()
|
||||
renewed = True
|
||||
except Exception:
|
||||
renewed = False
|
||||
last_lock_time = current_time
|
||||
if not renewed or not lock.owned():
|
||||
task_logger.error(
|
||||
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
|
||||
)
|
||||
return False
|
||||
|
||||
# Mark UserFile as migrated
|
||||
uf.document_id_migrated = True
|
||||
db_session.add(uf)
|
||||
try:
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(
|
||||
user_file.document_id
|
||||
)
|
||||
normalized_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
clean_old_doc_id
|
||||
)
|
||||
user_project_ids = [project.id for project in user_file.projects]
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
|
||||
)
|
||||
|
||||
index_name = active_settings.primary.index_name
|
||||
|
||||
# First find the chunks count using direct Vespa query
|
||||
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
|
||||
|
||||
# Count all chunks for this document
|
||||
chunk_count = 0
|
||||
continuation = None
|
||||
while True:
|
||||
docs, continuation = _visit_chunks(
|
||||
http_client=HttpxPool.get("vespa"),
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
continuation=continuation,
|
||||
)
|
||||
if not docs:
|
||||
break
|
||||
chunk_count += len(docs)
|
||||
if not continuation:
|
||||
break
|
||||
|
||||
task_logger.info(
|
||||
f"Found {chunk_count} chunks for document {normalized_doc_id}"
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
updated_chunks = retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=VespaDocumentFields(document_id=str(user_file.id)),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = updated_chunks
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
actual_doc_id
|
||||
)
|
||||
if (
|
||||
normalized_doc_id in search_doc_map
|
||||
or normalized_actual_doc_id in search_doc_map
|
||||
):
|
||||
to_update = (
|
||||
search_doc_map[normalized_doc_id]
|
||||
if normalized_doc_id in search_doc_map
|
||||
else search_doc_map[normalized_actual_doc_id]
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
|
||||
)
|
||||
for search_doc in to_update:
|
||||
search_doc.document_id = str(user_file.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
user_file.document_id_migrated = True
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
updated_count += 1
|
||||
except Exception as per_file_exc:
|
||||
# Rollback the current transaction and continue with the next file
|
||||
db_session.rollback()
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
|
||||
f"{per_file_exc.__class__.__name__}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Updated {updated_count} SearchDoc records with new UUIDs"
|
||||
f"user_file_docid_migration_task - Updated {updated_count} user files"
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
|
||||
try:
|
||||
store = get_default_file_store()
|
||||
# Only supported for S3-backed stores where we can manipulate object keys
|
||||
if isinstance(store, S3BackedFileStore):
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.file_id.like("plaintext_%"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(
|
||||
Bucket=fr.bucket_name, Key=fr.object_key
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(
|
||||
f"Tenant={tenant_id} failed plaintext object normalize for "
|
||||
f"id={fr.file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
|
||||
)
|
||||
else:
|
||||
task_logger.info(
|
||||
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
|
||||
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
|
||||
f"(updated={updated_count}) exception={e.__class__.__name__}"
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
@@ -0,0 +1,10 @@
from celery import Celery

from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
    "onyx.background.celery.apps.background",
    "celery_app",
)
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
|
||||
|
||||
def get_old_index_attempts(
|
||||
@@ -21,6 +22,10 @@ def get_old_index_attempts(
|
||||
|
||||
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
|
||||
"""Clean up multiple index attempts"""
|
||||
db_session.query(IndexAttemptError).filter(
|
||||
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db_session.query(IndexAttempt).filter(
|
||||
IndexAttempt.id.in_(index_attempt_ids)
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
@@ -101,7 +102,6 @@ def _get_connector_runner(
|
||||
are the complete list of existing documents of the connector. If the task
|
||||
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
|
||||
"""
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
|
||||
|
||||
@@ -2,14 +2,15 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
@@ -26,12 +27,12 @@ from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import (
|
||||
process_streamed_packets,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.chat.turn import fast_chat_turn
|
||||
from onyx.chat.turn.infra.emitter import get_default_emitter
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.chat.user_files.parse_user_files import parse_user_files
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -89,6 +90,7 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
@@ -113,6 +115,10 @@ logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
|
||||
class PartialResponse(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
@@ -352,14 +358,12 @@ def stream_chat_message_objects(
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
persona = _get_persona_for_chat_session(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
default_persona=chat_session.persona,
|
||||
)
|
||||
|
||||
# TODO: remove once we have an endpoint for this stuff
|
||||
process_kg_commands(new_msg_req.message, persona.name, tenant_id, db_session)
|
||||
|
||||
@@ -731,12 +735,19 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
# TODO: for backwards compatibility, we are using the V1
|
||||
# user_message=default_build_user_message_v2(
|
||||
# user_query=final_msg.message,
|
||||
# prompt_config=prompt_config,
|
||||
# files=latest_query_files,
|
||||
# ),
|
||||
user_message=default_build_user_message(
|
||||
user_query=final_msg.message,
|
||||
prompt_config=prompt_config,
|
||||
files=latest_query_files,
|
||||
single_message_history=single_message_history,
|
||||
),
|
||||
# TODO: for backwards compatibility, we are using the V1
|
||||
# system_message=default_build_system_message_v2(prompt_config, llm.config),
|
||||
system_message=default_build_system_message(prompt_config, llm.config),
|
||||
message_history=message_history,
|
||||
llm_config=llm.config,
|
||||
@@ -780,10 +791,20 @@ def stream_chat_message_objects(
|
||||
project_instructions=project_instructions,
|
||||
)
|
||||
|
||||
# Process streamed packets using the new packet processing module
|
||||
yield from process_streamed_packets(
|
||||
from onyx.chat.packet_proccessing import process_streamed_packets
|
||||
|
||||
yield from process_streamed_packets.process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
)
|
||||
# TODO: For backwards compatible PR, switch back to the original call
|
||||
# yield from _fast_message_stream(
|
||||
# answer,
|
||||
# tools,
|
||||
# db_session,
|
||||
# get_redis_client(),
|
||||
# str(chat_session_id),
|
||||
# str(reserved_message_id),
|
||||
# )
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
@@ -820,6 +841,59 @@ def stream_chat_message_objects(
|
||||
return
|
||||
|
||||
|
||||
# TODO: Refactor this to live somewhere else
|
||||
def _fast_message_stream(
|
||||
answer: Answer,
|
||||
tools: list[Tool],
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
chat_session_id: UUID,
|
||||
reserved_message_id: int,
|
||||
) -> Generator[Packet, None, None]:
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.llm.litellm_singleton import LitellmModel
|
||||
|
||||
image_generation_tool_instance = None
|
||||
okta_profile_tool_instance = None
|
||||
for tool in tools:
|
||||
if isinstance(tool, ImageGenerationTool):
|
||||
image_generation_tool_instance = tool
|
||||
elif isinstance(tool, OktaProfileTool):
|
||||
okta_profile_tool_instance = tool
|
||||
converted_message_history = [
|
||||
PreviousMessage.from_langchain_msg(message, 0).to_agent_sdk_msg()
|
||||
for message in answer.graph_inputs.prompt_builder.build()
|
||||
]
|
||||
emitter = get_default_emitter()
|
||||
return fast_chat_turn.fast_chat_turn(
|
||||
messages=converted_message_history,
|
||||
# TODO: Maybe we can use some DI framework here?
|
||||
dependencies=ChatTurnDependencies(
|
||||
llm_model=LitellmModel(
|
||||
model=answer.graph_tooling.primary_llm.config.model_name,
|
||||
base_url=answer.graph_tooling.primary_llm.config.api_base,
|
||||
api_key=answer.graph_tooling.primary_llm.config.api_key,
|
||||
),
|
||||
llm=answer.graph_tooling.primary_llm,
|
||||
tools=tools_to_function_tools(tools),
|
||||
search_pipeline=answer.graph_tooling.search_tool,
|
||||
image_generation_tool=image_generation_tool_instance,
|
||||
okta_profile_tool=okta_profile_tool_instance,
|
||||
db_session=db_session,
|
||||
redis_client=redis_client,
|
||||
emitter=emitter,
|
||||
),
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=reserved_message_id,
|
||||
research_type=answer.graph_config.behavior.research_type,
|
||||
)
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
|
||||
@@ -21,8 +21,10 @@ from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
@@ -31,6 +33,33 @@ from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
def default_build_system_message_v2(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
|
||||
# for o-series markdown generation
|
||||
if (
|
||||
llm_config.model_provider == OPENAI_PROVIDER_NAME
|
||||
and llm_config.model_name.startswith("o")
|
||||
):
|
||||
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt,
|
||||
prompt_config,
|
||||
add_additional_info_if_no_tag=prompt_config.datetime_aware,
|
||||
)
|
||||
|
||||
if not tag_handled_prompt:
|
||||
return None
|
||||
|
||||
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
@@ -52,9 +81,29 @@ def default_build_system_message(
|
||||
if not tag_handled_prompt:
|
||||
return None
|
||||
|
||||
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
|
||||
|
||||
def default_build_user_message_v2(
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
) -> HumanMessage:
|
||||
user_prompt = user_query
|
||||
user_prompt = user_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
|
||||
user_msg = HumanMessage(
|
||||
content=(
|
||||
build_content_with_imgs(tag_handled_prompt, files)
|
||||
if files
|
||||
else tag_handled_prompt
|
||||
)
|
||||
)
|
||||
return user_msg
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
|
||||
56
backend/onyx/chat/stop_signal_checker.py
Normal file
@@ -0,0 +1,56 @@
from uuid import UUID

from redis.client import Redis

from shared_configs.contextvars import get_current_tenant_id

# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"


def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
    """
    Set or clear the stop signal fence for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
        value: True to set the fence (stop signal), False to clear it
    """
    tenant_id = get_current_tenant_id()
    fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
    if not value:
        redis_client.delete(fence_key)
        return

    redis_client.set(fence_key, 0)


def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
    """
    Check if the chat session should continue (not stopped).

    Args:
        chat_session_id: The UUID of the chat session to check
        redis_client: Redis client to use for checking the stop signal

    Returns:
        True if the session should continue, False if it should stop
    """
    tenant_id = get_current_tenant_id()
    fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
    return not bool(redis_client.exists(fence_key))


def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
    """
    Clear the stop signal for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
    """
    tenant_id = get_current_tenant_id()
    fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
    redis_client.delete(fence_key)
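A minimal usage sketch of the stop-signal helpers above (not part of the diff), assuming a reachable Redis instance and an active tenant context; the connection details and session UUID are placeholders:

from uuid import uuid4

from redis import Redis

from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.chat.stop_signal_checker import set_fence

redis_client = Redis(host="localhost", port=6379)  # placeholder connection details
chat_session_id = uuid4()

reset_cancel_status(chat_session_id, redis_client)      # start of a turn: no stop signal
assert is_connected(chat_session_id, redis_client)      # fence absent -> keep streaming

set_fence(chat_session_id, redis_client, value=True)    # user requests a stop
assert not is_connected(chat_session_id, redis_client)  # fence present -> abort the turn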
1
backend/onyx/chat/turn/__init__.py
Normal file
@@ -0,0 +1 @@
# Turn module for chat functionality
257
backend/onyx/chat/turn/fast_chat_turn.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from agents import Agent
|
||||
from agents import ModelSettings
|
||||
from agents import RawResponsesStreamEvent
|
||||
from agents import StopAtTools
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from onyx.chat.turn.infra.chat_turn_event_stream import unified_event_stream
|
||||
from onyx.chat.turn.infra.session_sink import extract_final_answer_from_packets
|
||||
from onyx.chat.turn.infra.session_sink import save_iteration
|
||||
from onyx.chat.turn.infra.sync_agent_stream_adapter import SyncAgentStream
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketObj
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
|
||||
|
||||
|
||||
def _fast_chat_turn_core(
|
||||
messages: list[dict],
|
||||
dependencies: ChatTurnDependencies,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
research_type: ResearchType,
|
||||
# Dependency injectable arguments for testing
|
||||
starter_global_iteration_responses: list[IterationAnswer] | None = None,
|
||||
starter_cited_documents: list[InferenceSection] | None = None,
|
||||
) -> None:
|
||||
"""Core fast chat turn logic that allows overriding global_iteration_responses for testing.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
dependencies: Chat turn dependencies
|
||||
chat_session_id: Chat session ID
|
||||
message_id: Message ID
|
||||
research_type: Research type
|
||||
        starter_global_iteration_responses: Optional list of iteration answers to inject for testing
        starter_cited_documents: Optional list of cited documents to inject for testing
|
||||
"""
|
||||
reset_cancel_status(
|
||||
chat_session_id,
|
||||
dependencies.redis_client,
|
||||
)
|
||||
ctx = ChatTurnContext(
|
||||
run_dependencies=dependencies,
|
||||
aggregated_context=AggregatedDRContext(
|
||||
context="context",
|
||||
cited_documents=starter_cited_documents or [],
|
||||
is_internet_marker_dict={},
|
||||
global_iteration_responses=starter_global_iteration_responses or [],
|
||||
),
|
||||
iteration_instructions=[],
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=message_id,
|
||||
research_type=research_type,
|
||||
)
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
model=dependencies.llm_model,
|
||||
tools=cast(list, dependencies.tools), # type: ignore[arg-type]
|
||||
model_settings=ModelSettings(
|
||||
temperature=dependencies.llm.config.temperature,
|
||||
include_usage=True,
|
||||
),
|
||||
tool_use_behavior=StopAtTools(stop_at_tool_names=[image_generation_tool.name]),
|
||||
)
|
||||
# By default, the agent can only take 10 turns. For our use case, it should be higher.
|
||||
max_turns = 25
|
||||
agent_stream: SyncAgentStream = SyncAgentStream(
|
||||
agent=agent,
|
||||
input=messages,
|
||||
context=ctx,
|
||||
max_turns=max_turns,
|
||||
)
|
||||
for ev in agent_stream:
|
||||
connected = is_connected(
|
||||
chat_session_id,
|
||||
dependencies.redis_client,
|
||||
)
|
||||
if not connected:
|
||||
_emit_clean_up_packets(dependencies, ctx)
|
||||
agent_stream.cancel()
|
||||
break
|
||||
obj = _default_packet_translation(ev, ctx)
|
||||
if obj:
|
||||
dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
|
||||
final_answer = extract_final_answer_from_packets(
|
||||
dependencies.emitter.packet_history
|
||||
)
|
||||
|
||||
all_cited_documents = []
|
||||
if ctx.aggregated_context.global_iteration_responses:
|
||||
context_docs = _gather_context_docs_from_iteration_answers(
|
||||
ctx.aggregated_context.global_iteration_responses
|
||||
)
|
||||
all_cited_documents = context_docs
|
||||
if context_docs and final_answer:
|
||||
_process_citations_for_final_answer(
|
||||
final_answer=final_answer,
|
||||
context_docs=context_docs,
|
||||
dependencies=dependencies,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
save_iteration(
|
||||
db_session=dependencies.db_session,
|
||||
message_id=message_id,
|
||||
chat_session_id=chat_session_id,
|
||||
research_type=research_type,
|
||||
ctx=ctx,
|
||||
final_answer=final_answer,
|
||||
all_cited_documents=all_cited_documents,
|
||||
)
|
||||
dependencies.emitter.emit(
|
||||
Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
|
||||
)
|
||||
|
||||
|
||||
@unified_event_stream
|
||||
def fast_chat_turn(
|
||||
messages: list[dict],
|
||||
dependencies: ChatTurnDependencies,
|
||||
chat_session_id: UUID,
|
||||
message_id: int,
|
||||
research_type: ResearchType,
|
||||
) -> None:
|
||||
"""Main fast chat turn function that calls the core logic with default parameters."""
|
||||
_fast_chat_turn_core(
|
||||
messages,
|
||||
dependencies,
|
||||
chat_session_id,
|
||||
message_id,
|
||||
research_type,
|
||||
starter_global_iteration_responses=None,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Maybe in general there's a cleaner way to handle cancellation in the middle of a tool call?
|
||||
def _emit_clean_up_packets(
|
||||
dependencies: ChatTurnDependencies, ctx: ChatTurnContext
|
||||
) -> None:
|
||||
if not (
|
||||
dependencies.emitter.packet_history
|
||||
and dependencies.emitter.packet_history[-1].obj.type == "message_delta"
|
||||
):
|
||||
dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=ctx.current_run_step,
|
||||
obj=MessageStart(
|
||||
type="message_start", content="Cancelled", final_documents=None
|
||||
),
|
||||
)
|
||||
)
|
||||
dependencies.emitter.emit(
|
||||
Packet(ind=ctx.current_run_step, obj=SectionEnd(type="section_end"))
|
||||
)
|
||||
|
||||
|
||||
def _gather_context_docs_from_iteration_answers(
|
||||
iteration_answers: list[IterationAnswer],
|
||||
) -> list[InferenceSection]:
|
||||
"""Gather cited documents from iteration answers for citation processing."""
|
||||
context_docs: list[InferenceSection] = []
|
||||
|
||||
for iteration_answer in iteration_answers:
|
||||
# Extract cited documents from this iteration
|
||||
for inference_section in iteration_answer.cited_documents.values():
|
||||
# Avoid duplicates by checking document_id
|
||||
if not any(
|
||||
doc.center_chunk.document_id
|
||||
== inference_section.center_chunk.document_id
|
||||
for doc in context_docs
|
||||
):
|
||||
context_docs.append(inference_section)
|
||||
|
||||
return context_docs
|
||||
|
||||
|
||||
def _process_citations_for_final_answer(
|
||||
final_answer: str,
|
||||
context_docs: list[InferenceSection],
|
||||
dependencies: ChatTurnDependencies,
|
||||
ctx: ChatTurnContext,
|
||||
) -> None:
|
||||
    """Process citations in the final answer and emit citation events."""
    index = ctx.current_run_step + 1
|
||||
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
|
||||
# Convert InferenceSection objects to LlmDoc objects for citation processing
|
||||
llm_docs = [llm_doc_from_inference_section(section) for section in context_docs]
|
||||
|
||||
# Create document ID to rank mappings (simple 1-based indexing)
|
||||
final_doc_id_to_rank_map = DocumentIdOrderMapping(
|
||||
order_mapping={doc.document_id: i + 1 for i, doc in enumerate(llm_docs)}
|
||||
)
|
||||
display_doc_id_to_rank_map = final_doc_id_to_rank_map # Same mapping for display
|
||||
|
||||
# Initialize citation processor
|
||||
citation_processor = CitationProcessor(
|
||||
context_docs=llm_docs,
|
||||
final_doc_id_to_rank_map=final_doc_id_to_rank_map,
|
||||
display_doc_id_to_rank_map=display_doc_id_to_rank_map,
|
||||
)
|
||||
|
||||
# Process the final answer through citation processor
|
||||
collected_citations: list = []
|
||||
for response_part in citation_processor.process_token(final_answer):
|
||||
if hasattr(response_part, "citation_num"): # It's a CitationInfo
|
||||
collected_citations.append(response_part)
|
||||
|
||||
# Emit citation events if we found any citations
|
||||
if collected_citations:
|
||||
dependencies.emitter.emit(Packet(ind=index, obj=CitationStart()))
|
||||
dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CitationDelta(citations=collected_citations), # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
dependencies.emitter.emit(Packet(ind=index, obj=SectionEnd(type="section_end")))
|
||||
ctx.current_run_step = index
|
||||
|
||||
|
||||
def _default_packet_translation(ev: object, ctx: ChatTurnContext) -> PacketObj | None:
|
||||
if isinstance(ev, RawResponsesStreamEvent):
|
||||
# TODO: might need some variation here for different types of models
|
||||
# OpenAI packet translator
|
||||
obj: PacketObj | None = None
|
||||
if ev.data.type == "response.content_part.added":
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
ctx.aggregated_context.cited_documents
|
||||
)
|
||||
obj = MessageStart(
|
||||
type="message_start", content="", final_documents=retrieved_search_docs
|
||||
)
|
||||
elif ev.data.type == "response.output_text.delta":
|
||||
obj = MessageDelta(type="message_delta", content=ev.data.delta)
|
||||
elif ev.data.type == "response.content_part.done":
|
||||
obj = SectionEnd(type="section_end")
|
||||
return obj
|
||||
return None
|
||||
1
backend/onyx/chat/turn/infra/__init__.py
Normal file
@@ -0,0 +1 @@
# Infrastructure module for chat turn orchestration
57
backend/onyx/chat/turn/infra/chat_turn_event_stream.py
Normal file
@@ -0,0 +1,57 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Dict
from typing import List

from onyx.chat.turn.models import ChatTurnDependencies
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background


def unified_event_stream(
    turn_func: Callable[..., None],
) -> Callable[..., Generator[Packet, None]]:
    """
    Decorator that wraps a turn_func to provide event streaming capabilities.

    Usage:
        @unified_event_stream
        def my_turn_func(messages, dependencies, *args, **kwargs):
            # Your turn logic here
            pass

    Then call it like:
        generator = my_turn_func(messages, dependencies, *args, **kwargs)
    """

    def wrapper(
        messages: List[Dict[str, Any]],
        dependencies: ChatTurnDependencies,
        *args: Any,
        **kwargs: Any
    ) -> Generator[Packet, None]:
        def run_with_exception_capture() -> None:
            try:
                turn_func(messages, dependencies, *args, **kwargs)
            except Exception as e:
                dependencies.emitter.emit(
                    Packet(ind=0, obj=PacketException(type="error", exception=e))
                )

        thread = run_in_background(run_with_exception_capture)
        while True:
            pkt: Packet = dependencies.emitter.bus.get()
            if pkt.obj == OverallStop(type="stop"):
                yield pkt
                break
            elif isinstance(pkt.obj, PacketException):
                raise pkt.obj.exception
            else:
                yield pkt
        wait_on_background(thread)

    return wrapper
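A hedged consumption sketch for the decorator above: the decorated turn runs in a background thread, and the wrapper re-yields packets from the emitter's queue until OverallStop arrives. build_dependencies, handle_packet, and the id/research-type values are placeholders, not part of this change:

from onyx.chat.turn.fast_chat_turn import fast_chat_turn

dependencies = build_dependencies()  # hypothetical helper that assembles ChatTurnDependencies
for packet in fast_chat_turn(
    messages=[{"role": "user", "content": "hello"}],
    dependencies=dependencies,
    chat_session_id=chat_session_id,   # placeholder UUID
    message_id=reserved_message_id,    # placeholder reserved message id
    research_type=research_type,       # placeholder ResearchType value
):
    # Each Packet was emitted by the background turn; iteration ends after OverallStop.
    handle_packet(packet)  # hypothetical downstream handler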
21
backend/onyx/chat/turn/infra/emitter.py
Normal file
@@ -0,0 +1,21 @@
from queue import Queue

from onyx.server.query_and_chat.streaming_models import Packet


class Emitter:
    """Use this inside tools to emit arbitrary UI progress."""

    def __init__(self, bus: Queue):
        self.bus = bus
        self.packet_history: list[Packet] = []

    def emit(self, packet: Packet) -> None:
        self.bus.put(packet)
        self.packet_history.append(packet)


def get_default_emitter() -> Emitter:
    bus: Queue[Packet] = Queue()
    emitter = Emitter(bus)
    return emitter
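A small illustrative sketch of the emitter contract (the packet payload is an arbitrary example): producers call emit(), the streaming wrapper drains the underlying queue, and packet_history keeps everything that was emitted:

from onyx.chat.turn.infra.emitter import get_default_emitter
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import Packet

emitter = get_default_emitter()
emitter.emit(Packet(ind=0, obj=MessageDelta(type="message_delta", content="partial text")))

queued = emitter.bus.get()                   # what the streaming wrapper reads
assert queued is emitter.packet_history[-1]  # history retains every emitted packet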
168
backend/onyx/chat/turn/infra/session_sink.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# TODO: Figure out a way to persist information that is robust to cancellation,
|
||||
# modular so easily testable in unit tests and evals [likely injecting some higher
|
||||
# level session manager and span sink], potentially has some robustness off the critical path,
|
||||
# and promotes clean separation of concerns.
|
||||
import re
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
def save_iteration(
|
||||
db_session: Session,
|
||||
message_id: int,
|
||||
chat_session_id: UUID,
|
||||
research_type: ResearchType,
|
||||
ctx: ChatTurnContext,
|
||||
final_answer: str,
|
||||
all_cited_documents: list[InferenceSection],
|
||||
) -> None:
|
||||
# first, insert the search_docs
|
||||
is_internet_marker_dict: dict[str, bool] = {}
|
||||
search_docs = [
|
||||
create_search_doc_from_inference_section(
|
||||
inference_section=inference_section,
|
||||
is_internet=is_internet_marker_dict.get(
|
||||
inference_section.center_chunk.document_id, False
|
||||
), # TODO: revisit
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
|
||||
# then, map_search_docs to message
|
||||
_insert_chat_message_search_doc_pair(
|
||||
message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
)
|
||||
|
||||
# lastly, insert the citations
|
||||
citation_dict: dict[int, int] = {}
|
||||
cited_doc_nrs = _extract_citation_numbers(final_answer)
|
||||
if search_docs:
|
||||
for cited_doc_nr in cited_doc_nrs:
|
||||
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=ctx.run_dependencies.llm.config.model_name,
|
||||
provider_type=ctx.run_dependencies.llm.config.model_provider,
|
||||
)
|
||||
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
|
||||
# Update the chat message and its parent message in database
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=chat_session_id,
|
||||
is_agentic=research_type == ResearchType.DEEP,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
research_type=research_type,
|
||||
research_plan={},
|
||||
final_documents=search_docs,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
token_count=num_tokens,
|
||||
)
|
||||
|
||||
# TODO: I don't think this is the ideal schema for all use cases
|
||||
# find a better schema to store tool and reasoning calls
|
||||
for iteration_preparation in ctx.iteration_instructions:
|
||||
research_agent_iteration_step = ResearchAgentIteration(
|
||||
primary_question_id=message_id,
|
||||
reasoning=iteration_preparation.reasoning,
|
||||
purpose=iteration_preparation.purpose,
|
||||
iteration_nr=iteration_preparation.iteration_nr,
|
||||
)
|
||||
db_session.add(research_agent_iteration_step)
|
||||
|
||||
for iteration_answer in ctx.aggregated_context.global_iteration_responses:
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
list(iteration_answer.cited_documents.values())
|
||||
)
|
||||
|
||||
# Convert SavedSearchDoc objects to JSON-serializable format
|
||||
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
|
||||
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
primary_question_id=message_id,
|
||||
iteration_nr=iteration_answer.iteration_nr,
|
||||
iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
sub_step_instructions=iteration_answer.question,
|
||||
sub_step_tool_id=iteration_answer.tool_id,
|
||||
sub_answer=iteration_answer.answer,
|
||||
reasoning=iteration_answer.reasoning,
|
||||
claims=iteration_answer.claims,
|
||||
cited_doc_results=serialized_search_docs,
|
||||
generated_images=(
|
||||
GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
if iteration_answer.generated_images
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _insert_chat_message_search_doc_pair(
|
||||
message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
) -> None:
|
||||
"""
|
||||
    Insert (message_id, search_doc_id) pairs into the chat_message__search_doc table.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link
        db_session: The database session
|
||||
"""
|
||||
for search_doc_id in search_doc_ids:
|
||||
chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
)
|
||||
db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def _extract_citation_numbers(text: str) -> list[int]:
|
||||
"""
|
||||
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
Returns a list of all unique citation numbers found.
|
||||
"""
|
||||
# Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
cited_numbers = []
|
||||
for match in matches:
|
||||
# Split by comma and extract all numbers
|
||||
numbers = [int(num.strip()) for num in match.split(",")]
|
||||
cited_numbers.extend(numbers)
|
||||
|
||||
return list(set(cited_numbers)) # Return unique numbers
|
||||
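Illustrative behavior of the citation regex above (the ordering of the result is unspecified because duplicates are dropped via a set):

_extract_citation_numbers("See [[1]] and [[2, 5]], then [[1]] again.")
# -> the unique numbers {1, 2, 5} as a list, in arbitrary order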
|
||||
|
||||
def extract_final_answer_from_packets(packet_history: list[Packet]) -> str:
|
||||
"""Extract the final answer by concatenating all MessageDelta content."""
|
||||
final_answer = ""
|
||||
for packet in packet_history:
|
||||
if isinstance(packet.obj, MessageDelta) or isinstance(packet.obj, MessageStart):
|
||||
final_answer += packet.obj.content
|
||||
return final_answer
|
||||
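Illustrative only: the final answer is simply the concatenation of MessageStart and MessageDelta contents in emission order, for example:

history = [
    Packet(ind=0, obj=MessageStart(type="message_start", content="", final_documents=None)),
    Packet(ind=0, obj=MessageDelta(type="message_delta", content="Hello, ")),
    Packet(ind=0, obj=MessageDelta(type="message_delta", content="world.")),
]
assert extract_final_answer_from_packets(history) == "Hello, world."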
177
backend/onyx/chat/turn/infra/sync_agent_stream_adapter.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from agents import Agent
|
||||
from agents import RunResultStreaming
|
||||
from agents.run import Runner
|
||||
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SyncAgentStream(Generic[T]):
|
||||
"""
|
||||
Convert an async streamed run into a sync iterator with cooperative cancellation.
|
||||
Runs the Agent in a background thread.
|
||||
|
||||
Usage:
|
||||
        adapter = SyncAgentStream(
|
||||
agent=agent,
|
||||
input=input,
|
||||
context=context,
|
||||
max_turns=100,
|
||||
queue_maxsize=0, # optional backpressure
|
||||
)
|
||||
for ev in adapter: # sync iteration
|
||||
...
|
||||
# or cancel from elsewhere:
|
||||
adapter.cancel()
|
||||
"""
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: Agent,
|
||||
input: list[dict],
|
||||
context: ChatTurnContext,
|
||||
max_turns: int = 100,
|
||||
queue_maxsize: int = 0,
|
||||
) -> None:
|
||||
self._agent = agent
|
||||
self._input = input
|
||||
self._context = context
|
||||
self._max_turns = max_turns
|
||||
|
||||
self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._streamed: RunResultStreaming | None = None
|
||||
self._exc: Optional[BaseException] = None
|
||||
self._cancel_requested = threading.Event()
|
||||
self._started = threading.Event()
|
||||
self._done = threading.Event()
|
||||
|
||||
self._start_thread()
|
||||
|
||||
# ---------- public sync API ----------
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
try:
|
||||
while True:
|
||||
item = self._q.get()
|
||||
if item is self._SENTINEL:
|
||||
# If the consumer thread raised, surface it now
|
||||
if self._exc is not None:
|
||||
raise self._exc
|
||||
# Normal completion
|
||||
return
|
||||
yield item # type: ignore[misc,return-value]
|
||||
finally:
|
||||
# Ensure we fully clean up whether we exited due to exception,
|
||||
# StopIteration, or external cancel.
|
||||
self.close()
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""
|
||||
Cooperatively cancel the underlying streamed run and shut down.
|
||||
Safe to call multiple times and from any thread.
|
||||
"""
|
||||
self._cancel_requested.set()
|
||||
loop = self._loop
|
||||
streamed = self._streamed
|
||||
if loop is not None and streamed is not None and not self._done.is_set():
|
||||
loop.call_soon_threadsafe(streamed.cancel)
|
||||
return True
|
||||
return False
|
||||
|
||||
def close(self, *, wait: bool = True) -> None:
|
||||
"""Idempotent shutdown."""
|
||||
self.cancel()
|
||||
# ask the loop to stop if it's still running
|
||||
loop = self._loop
|
||||
if loop is not None and loop.is_running():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
except Exception:
|
||||
pass
|
||||
# join the thread
|
||||
if wait and self._thread is not None and self._thread.is_alive():
|
||||
self._thread.join(timeout=5.0)
|
||||
|
||||
# ---------- internals ----------
|
||||
|
||||
def _start_thread(self) -> None:
|
||||
t = run_in_background(self._thread_main)
|
||||
self._thread = t
|
||||
# Optionally wait until the loop/worker is started so .cancel() is safe soon after init
|
||||
self._started.wait(timeout=1.0)
|
||||
|
||||
def _thread_main(self) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
self._loop = loop
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def worker() -> None:
|
||||
try:
|
||||
# Start the streamed run inside the loop thread
|
||||
self._streamed = Runner.run_streamed(
|
||||
self._agent,
|
||||
self._input, # type: ignore[arg-type]
|
||||
context=self._context,
|
||||
max_turns=self._max_turns,
|
||||
)
|
||||
|
||||
# If cancel was requested before we created _streamed, honor it now
|
||||
if self._cancel_requested.is_set():
|
||||
await self._streamed.cancel() # type: ignore[func-returns-value]
|
||||
|
||||
# Consume async events and forward into the thread-safe queue
|
||||
async for ev in self._streamed.stream_events():
|
||||
# Early exit if a late cancel arrives
|
||||
if self._cancel_requested.is_set():
|
||||
# Try to cancel gracefully; don't break until cancel takes effect
|
||||
try:
|
||||
await self._streamed.cancel() # type: ignore[func-returns-value]
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
# This put() may block if queue_maxsize > 0 (backpressure)
|
||||
self._q.put(ev)
|
||||
|
||||
except BaseException as e:
|
||||
# Save exception to surface on the sync iterator side
|
||||
self._exc = e
|
||||
finally:
|
||||
# Signal end-of-stream
|
||||
self._q.put(self._SENTINEL)
|
||||
self._done.set()
|
||||
|
||||
# Mark started and run the worker to completion
|
||||
self._started.set()
|
||||
try:
|
||||
loop.run_until_complete(worker())
|
||||
finally:
|
||||
try:
|
||||
# Drain pending tasks/callbacks safely
|
||||
pending = asyncio.all_tasks(loop=loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
self._loop = None
|
||||
self._streamed = None
|
||||
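A hedged sketch of the cancellation pattern the chat turn loop relies on; agent, messages, ctx, should_stop, and handle_event below are placeholders for whatever the caller builds:

stream = SyncAgentStream(agent=agent, input=messages, context=ctx, max_turns=25)
for event in stream:
    if should_stop():    # e.g. the Redis stop fence was set for this session
        stream.cancel()  # cooperative: the background loop cancels the streamed run
        break
    handle_event(event)
stream.close()           # idempotent; also runs automatically when iteration ends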
50
backend/onyx/chat/turn/models.py
Normal file
@@ -0,0 +1,50 @@
import dataclasses
from dataclasses import dataclass
from uuid import UUID

from agents import FunctionTool
from agents import Model
from redis.client import Redis
from sqlalchemy.orm import Session

from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.turn.infra.emitter import Emitter
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
    OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool


@dataclass
class ChatTurnDependencies:
    llm_model: Model
    llm: LLM
    db_session: Session
    tools: list[FunctionTool]
    redis_client: Redis
    emitter: Emitter
    search_pipeline: SearchTool | None = None
    image_generation_tool: ImageGenerationTool | None = None
    okta_profile_tool: OktaProfileTool | None = None


@dataclass
class ChatTurnContext:
    """Context class to hold search tool and other dependencies"""

    chat_session_id: UUID
    message_id: int
    research_type: ResearchType
    run_dependencies: ChatTurnDependencies
    aggregated_context: AggregatedDRContext
    current_run_step: int = 0
    iteration_instructions: list[IterationInstructions] = dataclasses.field(
        default_factory=list
    )
    web_fetch_results: list[dict] = dataclasses.field(default_factory=list)
@@ -24,8 +24,6 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
|
||||
|
||||
#####
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
@@ -351,10 +349,6 @@ except ValueError:
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_PRIMARY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
|
||||
)
|
||||
@@ -362,18 +356,28 @@ CELERY_WORKER_PRIMARY_CONCURRENCY = int(
|
||||
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
|
||||
try:
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
# Consolidated background worker (merges heavy, kg_processing, monitoring, user_file_processing)
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 6
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kubernetes deployments)
|
||||
CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
|
||||
)
|
||||
|
||||
CELERY_WORKER_MONITORING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
|
||||
)
|
||||
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY") or 2
|
||||
)
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 8192
|
||||
@@ -511,6 +515,10 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
BLOB_STORAGE_SIZE_THRESHOLD = int(
|
||||
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -756,7 +764,7 @@ MAX_FEDERATED_CHUNKS = int(
|
||||
# NOTE: this should only be enabled if you have purchased an enterprise license.
|
||||
# if you're interested in an enterprise license, please reach out to us at
|
||||
# founders@onyx.app OR message Chris Weaver or Yuhong Sun in the Onyx
|
||||
# Slack community (https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ)
|
||||
# Discord community https://discord.gg/4NA5SbzrWb
|
||||
ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -90,6 +90,7 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
|
||||
# Internet Search
|
||||
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
|
||||
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
|
||||
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
@@ -6,7 +6,7 @@ from enum import Enum
|
||||
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -72,12 +72,13 @@ POSTGRES_CELERY_APP_NAME = "celery"
|
||||
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
|
||||
"celery_worker_user_file_processing"
|
||||
)
|
||||
@@ -108,7 +109,6 @@ KV_CUSTOMER_UUID_KEY = "customer_uuid"
|
||||
KV_INSTANCE_DOMAIN_KEY = "instance_domain"
|
||||
KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
KV_KG_CONFIG_KEY = "kg_config"
|
||||
|
||||
# NOTE: we use this timeout / 4 in various places to refresh a lock
|
||||
@@ -146,6 +146,13 @@ CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
# Doc ID migration can be long-running; use a longer TTL and renew periodically
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT = 10 * 60 # 10 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
|
||||
@@ -405,6 +412,7 @@ class OnyxRedisLocks:
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DOCID_MIGRATION_LOCK = "da_lock:user_file_docid_migration"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
|
||||
@@ -61,21 +61,16 @@ _BASE_EMBEDDING_MODELS = [
|
||||
dim=1536,
|
||||
index_name="danswer_chunk_text_embedding_3_small",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/gemini-embedding-001",
|
||||
dim=3072,
|
||||
index_name="danswer_chunk_google_gemini_embedding_001",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/text-embedding-005",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_text_embedding_005",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/textembedding-gecko@003",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_textembedding_gecko_003",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/textembedding-gecko@003",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_textembedding_gecko_003",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="voyage/voyage-large-2-instruct",
|
||||
dim=1024,
|
||||
|
||||
@@ -41,7 +41,7 @@ All new connectors should have tests added to the `backend/tests/daily/connector
|
||||
|
||||
#### Implementing the new Connector
|
||||
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, CheckpointedConnector, or CheckpointedConnectorWithPermSync
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
@@ -56,7 +56,7 @@ class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
class BitbucketConnector(
|
||||
CheckpointedConnector[BitbucketConnectorCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
):
|
||||
"""Connector for indexing Bitbucket Cloud pull requests.
|
||||
|
||||
@@ -266,7 +266,7 @@ class BitbucketConnector(
|
||||
"""Validate and deserialize a checkpoint instance from JSON."""
|
||||
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from numbers import Integral
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,6 +17,7 @@ from botocore.exceptions import PartialCredentialsError
|
||||
from botocore.session import get_session
|
||||
from mypy_boto3_s3 import S3Client # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import BlobType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -44,6 +47,10 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
|
||||
SIZE_THRESHOLD_BUFFER = 64
|
||||
|
||||
|
||||
class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -58,6 +65,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.s3_client: Optional[S3Client] = None
|
||||
self._allow_images: bool | None = None
|
||||
self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD
|
||||
|
||||
def set_allow_images(self, allow_images: bool) -> None:
|
||||
"""Set whether to process images in this connector."""
|
||||
@@ -195,11 +203,43 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
return None
|
||||
|
||||
def _download_object(self, key: str) -> bytes:
|
||||
def _download_object(self, key: str) -> bytes | None:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError("Blob storage")
|
||||
object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
|
||||
return object["Body"].read()
|
||||
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
|
||||
body = response["Body"]
|
||||
|
||||
try:
|
||||
if self.size_threshold is None:
|
||||
return body.read()
|
||||
|
||||
return self._read_stream_with_limit(body, key)
|
||||
finally:
|
||||
body.close()
|
||||
|
||||
def _read_stream_with_limit(self, body: Any, key: str) -> bytes | None:
|
||||
if self.size_threshold is None:
|
||||
return body.read()
|
||||
|
||||
bytes_read = 0
|
||||
chunks: list[bytes] = []
|
||||
chunk_size = min(
|
||||
DOWNLOAD_CHUNK_SIZE, self.size_threshold + SIZE_THRESHOLD_BUFFER
|
||||
)
|
||||
|
||||
for chunk in body.iter_chunks(chunk_size=chunk_size):
|
||||
if not chunk:
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
bytes_read += len(chunk)
|
||||
|
||||
if bytes_read > self.size_threshold + SIZE_THRESHOLD_BUFFER:
|
||||
logger.warning(
|
||||
f"{key} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
)
|
||||
return None
|
||||
|
||||
return b"".join(chunks)
|
||||
|
||||
# NOTE: Left in as may be useful for one-off access to documents and sharing across orgs.
|
||||
# def _get_presigned_url(self, key: str) -> str:
|
||||
@@ -236,6 +276,51 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_size_bytes(obj: Mapping[str, Any]) -> int | None:
|
||||
"""Return the first numeric size field found on the object metadata."""
|
||||
|
||||
candidate_keys = (
|
||||
"Size",
|
||||
"size",
|
||||
"ContentLength",
|
||||
"content_length",
|
||||
"Content-Length",
|
||||
"contentLength",
|
||||
"bytes",
|
||||
"Bytes",
|
||||
)
|
||||
|
||||
def _normalize(value: Any) -> int | None:
|
||||
if value is None or isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, Integral):
|
||||
return int(value)
|
||||
try:
|
||||
numeric = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
if numeric >= 0 and numeric.is_integer():
|
||||
return int(numeric)
|
||||
return None
|
||||
|
||||
for key in candidate_keys:
|
||||
if key in obj:
|
||||
normalized = _normalize(obj.get(key))
|
||||
if normalized is not None:
|
||||
return normalized
|
||||
|
||||
for key, value in obj.items():
|
||||
if not isinstance(key, str):
|
||||
continue
|
||||
lowered_key = key.lower()
|
||||
if "size" in lowered_key or "length" in lowered_key:
|
||||
normalized = _normalize(value)
|
||||
if normalized is not None:
|
||||
return normalized
|
||||
|
||||
return None
|
||||
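Illustrative inputs for the size extraction helper above; S3 list_objects_v2 entries expose "Size", while other metadata shapes may surface "ContentLength" or lowercase variants. The example keys and values are assumptions, not taken from this change:

BlobStorageConnector._extract_size_bytes({"Key": "a.pdf", "Size": 1048576})         # -> 1048576
BlobStorageConnector._extract_size_bytes({"Key": "b.pdf", "content_length": "42"})  # -> 42
BlobStorageConnector._extract_size_bytes({"Key": "c.pdf"})                          # -> None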
|
||||
def _yield_blob_objects(
|
||||
self,
|
||||
start: datetime,
|
||||
@@ -266,6 +351,18 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
key = obj["Key"]
|
||||
link = self._get_blob_link(key)
|
||||
|
||||
size_bytes = self._extract_size_bytes(obj)
|
||||
if (
|
||||
                self.size_threshold is not None
                and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logger.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle image files
|
||||
if is_accepted_file_ext(file_ext, OnyxExtensionType.Multimedia):
|
||||
if not self._allow_images:
|
||||
@@ -277,6 +374,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
# Process the image file
|
||||
try:
|
||||
downloaded_file = self._download_object(key)
|
||||
if downloaded_file is None:
|
||||
continue
|
||||
|
||||
# TODO: Refactor to avoid direct DB access in connector
|
||||
# This will require broader refactoring across the codebase
|
||||
@@ -309,6 +408,8 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
# Handle text and document files
|
||||
try:
|
||||
downloaded_file = self._download_object(key)
|
||||
if downloaded_file is None:
|
||||
continue
|
||||
extraction_result = extract_text_and_images(
|
||||
BytesIO(downloaded_file), file_name=file_name
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian.errors import ApiError # type: ignore
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -21,7 +22,6 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import build_confluence_document_id
|
||||
from onyx.connectors.confluence.utils import convert_attachment_to_content
|
||||
from onyx.connectors.confluence.utils import datetime_from_string
|
||||
from onyx.connectors.confluence.utils import process_attachment
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
@@ -41,6 +41,7 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
@@ -91,6 +92,7 @@ class ConfluenceCheckpoint(ConnectorCheckpoint):
|
||||
class ConfluenceConnector(
|
||||
CheckpointedConnector[ConfluenceCheckpoint],
|
||||
SlimConnector,
|
||||
SlimConnectorWithPermSync,
|
||||
CredentialsConnector,
|
||||
):
|
||||
def __init__(
|
||||
@@ -108,6 +110,7 @@ class ConfluenceConnector(
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
scoped_token: bool = False,
|
||||
) -> None:
|
||||
self.wiki_base = wiki_base
|
||||
self.is_cloud = is_cloud
|
||||
@@ -118,6 +121,7 @@ class ConfluenceConnector(
|
||||
self.batch_size = batch_size
|
||||
self.labels_to_skip = labels_to_skip
|
||||
self.timezone_offset = timezone_offset
|
||||
self.scoped_token = scoped_token
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
@@ -195,6 +199,7 @@ class ConfluenceConnector(
|
||||
is_cloud=self.is_cloud,
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -207,6 +212,7 @@ class ConfluenceConnector(
|
||||
url=self.wiki_base,
|
||||
credentials_provider=credentials_provider,
|
||||
timeout=3,
|
||||
scoped_token=self.scoped_token,
|
||||
)
|
||||
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
|
||||
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
|
||||
@@ -243,9 +249,26 @@ class ConfluenceConnector(
|
||||
page_query += " order by lastmodified asc"
|
||||
return page_query
|
||||
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
def _construct_attachment_query(
|
||||
self,
|
||||
confluence_page_id: str,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
# Add time filters to avoid reprocessing unchanged attachments during refresh
|
||||
if start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
attachment_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
attachment_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
attachment_query += " order by lastmodified asc"
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
@@ -299,41 +322,8 @@ class ConfluenceConnector(
|
||||
sections.append(
|
||||
TextSection(text=comment_text, link=f"{page_url}#comments")
|
||||
)
|
||||
|
||||
# Process attachments
|
||||
if "children" in page and "attachment" in page["children"]:
|
||||
attachments = self.confluence_client.get_attachments_for_page(
|
||||
page_id, expand="metadata"
|
||||
)
|
||||
|
||||
for attachment in attachments.get("results", []):
|
||||
# Process each attachment
|
||||
result = process_attachment(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_id,
|
||||
self.allow_images,
|
||||
)
|
||||
|
||||
if result and result.text:
|
||||
# Create a section for the attachment text
|
||||
attachment_section = TextSection(
|
||||
text=result.text,
|
||||
link=f"{page_url}#attachment-{attachment['id']}",
|
||||
)
|
||||
sections.append(attachment_section)
|
||||
elif result and result.file_name:
|
||||
# Create an ImageSection for image attachments
|
||||
image_section = ImageSection(
|
||||
link=f"{page_url}#attachment-{attachment['id']}",
|
||||
image_file_id=result.file_name,
|
||||
)
|
||||
sections.append(image_section)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Error processing attachment '{attachment.get('title')}':",
|
||||
f"{result.error if result else 'Unknown error'}",
|
||||
)
|
||||
# Note: attachments are no longer merged into the page document.
|
||||
# They are indexed as separate documents downstream.
|
||||
|
||||
# Extract metadata
|
||||
metadata = {}
|
||||
@@ -382,9 +372,20 @@ class ConfluenceConnector(
|
||||
)
|
||||
|
||||
def _fetch_page_attachments(
|
||||
self, page: dict[str, Any], doc: Document
|
||||
) -> Document | ConnectorFailure:
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
self,
|
||||
page: dict[str, Any],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> tuple[list[Document], list[ConnectorFailure]]:
|
||||
"""
|
||||
        Inline attachments are added directly to the document as text or image sections by
        this function. The returned documents/connectorfailures are for non-inline attachments
        and those at the end of the page.
        """
        attachment_query = self._construct_attachment_query(page["id"], start, end)
        attachment_failures: list[ConnectorFailure] = []
        attachment_docs: list[Document] = []
        page_url = ""

        for attachment in self.confluence_client.paginated_cql_retrieval(
            cql=attachment_query,
@@ -413,11 +414,17 @@ class ConfluenceConnector(
            logger.info(
                f"Processing attachment: {attachment['title']} attached to page {page['title']}"
            )

            # Attempt to get textual content or image summarization:
            object_url = build_confluence_document_id(
                self.wiki_base, attachment["_links"]["webui"], self.is_cloud
            )
            # Attachment document id: use the download URL for stable identity
            try:
                object_url = build_confluence_document_id(
                    self.wiki_base, attachment["_links"]["download"], self.is_cloud
                )
            except Exception as e:
                logger.warning(
                    f"Invalid attachment url for id {attachment['id']}, skipping"
                )
                logger.debug(f"Error building attachment url: {e}")
                continue
            try:
                response = convert_attachment_to_content(
                    confluence_client=self.confluence_client,
@@ -430,38 +437,76 @@ class ConfluenceConnector(

                content_text, file_storage_name = response

                sections: list[TextSection | ImageSection] = []
                if content_text:
                    doc.sections.append(
                        TextSection(
                            text=content_text,
                            link=object_url,
                        )
                    )
                    sections.append(TextSection(text=content_text, link=object_url))
                elif file_storage_name:
                    doc.sections.append(
                        ImageSection(
                            link=object_url,
                            image_file_id=file_storage_name,
                        )
                    sections.append(
                        ImageSection(link=object_url, image_file_id=file_storage_name)
                    )

                # Build attachment-specific metadata
                attachment_metadata: dict[str, str | list[str]] = {}
                if "space" in attachment:
                    attachment_metadata["space"] = attachment["space"].get("name", "")
                labels: list[str] = []
                if "metadata" in attachment and "labels" in attachment["metadata"]:
                    for label in attachment["metadata"]["labels"].get("results", []):
                        labels.append(label.get("name", ""))
                if labels:
                    attachment_metadata["labels"] = labels
                page_url = page_url or build_confluence_document_id(
                    self.wiki_base, page["_links"]["webui"], self.is_cloud
                )
                attachment_metadata["parent_page_id"] = page_url
                attachment_id = build_confluence_document_id(
                    self.wiki_base, attachment["_links"]["webui"], self.is_cloud
                )

                primary_owners: list[BasicExpertInfo] | None = None
                if "version" in attachment and "by" in attachment["version"]:
                    author = attachment["version"]["by"]
                    display_name = author.get("displayName", "Unknown")
                    email = author.get("email", "unknown@domain.invalid")
                    primary_owners = [
                        BasicExpertInfo(display_name=display_name, email=email)
                    ]

                attachment_doc = Document(
                    id=attachment_id,
                    sections=sections,
                    source=DocumentSource.CONFLUENCE,
                    semantic_identifier=attachment.get("title", object_url),
                    metadata=attachment_metadata,
                    doc_updated_at=(
                        datetime_from_string(attachment["version"]["when"])
                        if attachment.get("version")
                        and attachment["version"].get("when")
                        else None
                    ),
                    primary_owners=primary_owners,
                )
                attachment_docs.append(attachment_doc)
            except Exception as e:
                logger.error(
                    f"Failed to extract/summarize attachment {attachment['title']}",
                    exc_info=e,
                )
                if is_atlassian_date_error(
                    e
                ): # propagate error to be caught and retried
                if is_atlassian_date_error(e):
                    # propagate error to be caught and retried
                    raise
                return ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=doc.id,
                        document_link=object_url,
                    ),
                    failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {doc.id}",
                    exception=e,
                attachment_failures.append(
                    ConnectorFailure(
                        failed_document=DocumentFailure(
                            document_id=object_url,
                            document_link=object_url,
                        ),
                        failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}",
                        exception=e,
                    )
                )
        return doc

        return attachment_docs, attachment_failures

    def _fetch_document_batches(
        self,
@@ -500,12 +545,18 @@ class ConfluenceConnector(
            if isinstance(doc_or_failure, ConnectorFailure):
                yield doc_or_failure
                continue
            # Now get attachments for that page:
            doc_or_failure = self._fetch_page_attachments(page, doc_or_failure)

            # yield completed document (or failure)
            yield doc_or_failure

            # Now get attachments for that page:
            attachment_docs, attachment_failures = self._fetch_page_attachments(
                page, start, end
            )
            # yield attached docs and failures
            yield from attachment_docs
            yield from attachment_failures

            # Create checkpoint once a full page of results is returned
            if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url:
                return checkpoint
@@ -558,7 +609,21 @@ class ConfluenceConnector(
    def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
        return ConfluenceCheckpoint.model_validate_json(checkpoint_json)

    def retrieve_all_slim_documents(
    @override
    def retrieve_all_slim_docs(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        return self._retrieve_all_slim_docs(
            start=start,
            end=end,
            callback=callback,
            include_permissions=False,
        )

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
@@ -568,12 +633,28 @@ class ConfluenceConnector(
        Return 'slim' docs (IDs + minimal permission data).
        Does not fetch actual text. Used primarily for incremental permission sync.
        """
        return self._retrieve_all_slim_docs(
            start=start,
            end=end,
            callback=callback,
            include_permissions=True,
        )

    def _retrieve_all_slim_docs(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
        include_permissions: bool = True,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []
        restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)

        space_level_access_info = get_all_space_permissions(
            self.confluence_client, self.is_cloud
        )
        space_level_access_info: dict[str, ExternalAccess] = {}
        if include_permissions:
            space_level_access_info = get_all_space_permissions(
                self.confluence_client, self.is_cloud
            )

        def get_external_access(
            doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
@@ -600,8 +681,10 @@ class ConfluenceConnector(
            doc_metadata_list.append(
                SlimDocument(
                    id=page_id,
                    external_access=get_external_access(
                        page_id, page_restrictions, page_ancestors
                    external_access=(
                        get_external_access(page_id, page_restrictions, page_ancestors)
                        if include_permissions
                        else None
                    ),
                )
            )
@@ -636,8 +719,12 @@ class ConfluenceConnector(
            doc_metadata_list.append(
                SlimDocument(
                    id=attachment_id,
                    external_access=get_external_access(
                        attachment_id, attachment_restrictions, []
                    external_access=(
                        get_external_access(
                            attachment_id, attachment_restrictions, []
                        )
                        if include_permissions
                        else None
                    ),
                )
            )
@@ -648,10 +735,10 @@ class ConfluenceConnector(

            if callback and callback.should_stop():
                raise RuntimeError(
                    "retrieve_all_slim_documents: Stop signal detected"
                    "retrieve_all_slim_docs_perm_sync: Stop signal detected"
                )
            if callback:
                callback.progress("retrieve_all_slim_documents", 1)
                callback.progress("retrieve_all_slim_docs_perm_sync", 1)

            yield doc_metadata_list

@@ -676,6 +763,14 @@ class ConfluenceConnector(
                f"Unexpected error while validating Confluence settings: {e}"
            )

        if self.space:
            try:
                self.low_timeout_confluence_client.get_space(self.space)
            except ApiError as e:
                raise ConnectorValidationError(
                    "Invalid Confluence space key provided"
                ) from e

        if not spaces or not spaces.get("results"):
            raise ConnectorValidationError(
                "No Confluence spaces found. Either your credentials lack permissions, or "
@@ -724,7 +819,7 @@ if __name__ == "__main__":
    end = datetime.now().timestamp()

    # Fetch all `SlimDocuments`.
    for slim_doc in confluence_connector.retrieve_all_slim_documents():
    for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync():
        print(slim_doc)

    # Fetch all `Documents`.

@@ -41,6 +41,7 @@ from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
@@ -87,16 +88,20 @@ class OnyxConfluence:
        url: str,
        credentials_provider: CredentialsProviderInterface,
        timeout: int | None = None,
        scoped_token: bool = False,
        # should generally not be passed in, but making it overridable for
        # easier testing
        confluence_user_profiles_override: list[dict[str, str]] | None = (
            CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
        ),
    ) -> None:
        self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
        url = scoped_url(url, "confluence") if scoped_token else url

        self._is_cloud = is_cloud
        self._url = url.rstrip("/")
        self._credentials_provider = credentials_provider

        self.scoped_token = scoped_token
        self.redis_client: Redis | None = None
        self.static_credentials: dict[str, Any] | None = None
        if self._credentials_provider.is_dynamic():
@@ -218,6 +223,34 @@ class OnyxConfluence:

        with self._credentials_provider:
            credentials, _ = self._renew_credentials()
            if self.scoped_token:
                # v2 endpoint doesn't always work with scoped tokens, use v1
                token = credentials["confluence_access_token"]
                probe_url = f"{self.base_url}/rest/api/space?limit=1"
                import requests

                logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")

                try:
                    r = requests.get(
                        probe_url,
                        headers={"Authorization": f"Bearer {token}"},
                        timeout=10,
                    )
                    r.raise_for_status()
                except HTTPError as e:
                    if e.response.status_code == 403:
                        logger.warning(
                            "scoped token authenticated but not valid for probe endpoint (spaces)"
                        )
                    else:
                        if "WWW-Authenticate" in e.response.headers:
                            logger.warning(
                                f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
                            )
                        logger.warning(f"Full error: {e.response.text}")
                        raise e
                return

            # probe connection with direct client, no retries
            if "confluence_refresh_token" in credentials:
@@ -236,6 +269,7 @@ class OnyxConfluence:
                logger.info("Probing Confluence with Personal Access Token.")
                url = self._url
                if self._is_cloud:
                    logger.info("running with cloud client")
                    confluence_client_with_minimal_retries = Confluence(
                        url=url,
                        username=credentials["confluence_username"],
@@ -304,7 +338,9 @@ class OnyxConfluence:
            url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
            confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
        else:
            logger.info("Connecting to Confluence with Personal Access Token.")
            logger.info(
                f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
            )
            if self._is_cloud:
                confluence = Confluence(
                    url=self._url,
@@ -930,6 +966,13 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
    return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND


def sanitize_attachment_title(title: str) -> str:
    """
    Sanitize the attachment title to be a valid HTML attribute.
    """
    return title.replace("<", "_").replace(">", "_").replace(" ", "_").replace(":", "_")


def extract_text_from_confluence_html(
    confluence_client: OnyxConfluence,
    confluence_object: dict[str, Any],
@@ -1032,6 +1075,16 @@ def extract_text_from_confluence_html(
        except Exception as e:
            logger.warning(f"Error processing ac:link-body: {e}")

    for html_attachment in soup.findAll("ri:attachment"):
        # This extracts the text from inline attachments in the page so they can be
        # represented in the document text as plain text
        try:
            html_attachment.replaceWith(
                f"<attachment>{sanitize_attachment_title(html_attachment.attrs['ri:filename'])}</attachment>"
            ) # to be replaced later
        except Exception as e:
            logger.warning(f"Error processing ac:attachment: {e}")

    return format_document_soup(soup)

@@ -5,7 +5,10 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TypeVar
from urllib.parse import urljoin
from urllib.parse import urlparse

import requests
from dateutil.parser import parse

from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
@@ -148,3 +151,17 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:

def is_atlassian_date_error(e: Exception) -> bool:
    return "field 'updated' is invalid" in str(e)


def get_cloudId(base_url: str) -> str:
    tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
    response = requests.get(tenant_info_url, timeout=10)
    response.raise_for_status()
    return response.json()["cloudId"]


def scoped_url(url: str, product: str) -> str:
    parsed = urlparse(url)
    base_url = parsed.scheme + "://" + parsed.netloc
    cloud_id = get_cloudId(base_url)
    return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}"

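As an illustration of the helper above, a minimal sketch of how scoped_url rewrites a Confluence base URL into the api.atlassian.com gateway form. The site URL and cloudId here are hypothetical; the real value would come from the /_edge/tenant_info lookup in get_cloudId:

    # hypothetical site: assume get_cloudId("https://example.atlassian.net")
    # returns "1234-abcd" from the /_edge/tenant_info endpoint
    api_base = scoped_url("https://example.atlassian.net/wiki", "confluence")
    # -> "https://api.atlassian.com/ex/confluence/1234-abcd/wiki"
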
@@ -1,3 +1,4 @@
import importlib
from typing import Any
from typing import Type

@@ -6,60 +7,16 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.bitbucket.connector import BitbucketConnector
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
from onyx.connectors.gitbook.connector import GitbookConnector
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.gitlab.connector import GitlabConnector
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.jira.connector import JiraConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.outline.connector import OutlineConnector
from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
@@ -72,101 +29,75 @@ class ConnectorMissingException(Exception):
    pass


# Cache for already imported connector classes
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}


def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
    """Dynamically load and cache a connector class."""
    if source in _connector_cache:
        return _connector_cache[source]

    if source not in CONNECTOR_CLASS_MAP:
        raise ConnectorMissingException(f"Connector not found for source={source}")

    mapping = CONNECTOR_CLASS_MAP[source]

    try:
        module = importlib.import_module(mapping.module_path)
        connector_class = getattr(module, mapping.class_name)
        _connector_cache[source] = connector_class
        return connector_class
    except (ImportError, AttributeError) as e:
        raise ConnectorMissingException(
            f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
        )


def _validate_connector_supports_input_type(
    connector: Type[BaseConnector],
    input_type: InputType | None,
    source: DocumentSource,
) -> None:
    """Validate that a connector supports the requested input type."""
    if input_type is None:
        return

    # Check each input type requirement separately for clarity
    load_state_unsupported = input_type == InputType.LOAD_STATE and not issubclass(
        connector, LoadConnector
    )

    poll_unsupported = (
        input_type == InputType.POLL
        # Either poll or checkpoint works for this, in the future
        # all connectors should be checkpoint connectors
        and (
            not issubclass(connector, PollConnector)
            and not issubclass(connector, CheckpointedConnector)
        )
    )

    event_unsupported = input_type == InputType.EVENT and not issubclass(
        connector, EventConnector
    )

    if any([load_state_unsupported, poll_unsupported, event_unsupported]):
        raise ConnectorMissingException(
            f"Connector for source={source} does not accept input_type={input_type}"
        )


def identify_connector_class(
    source: DocumentSource,
    input_type: InputType | None = None,
) -> Type[BaseConnector]:
    connector_map = {
        DocumentSource.WEB: WebConnector,
        DocumentSource.FILE: LocalFileConnector,
        DocumentSource.SLACK: {
            InputType.POLL: SlackConnector,
            InputType.SLIM_RETRIEVAL: SlackConnector,
        },
        DocumentSource.GITHUB: GithubConnector,
        DocumentSource.GMAIL: GmailConnector,
        DocumentSource.GITLAB: GitlabConnector,
        DocumentSource.GITBOOK: GitbookConnector,
        DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
        DocumentSource.BOOKSTACK: BookstackConnector,
        DocumentSource.OUTLINE: OutlineConnector,
        DocumentSource.CONFLUENCE: ConfluenceConnector,
        DocumentSource.JIRA: JiraConnector,
        DocumentSource.PRODUCTBOARD: ProductboardConnector,
        DocumentSource.SLAB: SlabConnector,
        DocumentSource.NOTION: NotionConnector,
        DocumentSource.ZULIP: ZulipConnector,
        DocumentSource.GURU: GuruConnector,
        DocumentSource.LINEAR: LinearConnector,
        DocumentSource.HUBSPOT: HubSpotConnector,
        DocumentSource.DOCUMENT360: Document360Connector,
        DocumentSource.GONG: GongConnector,
        DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
        DocumentSource.ZENDESK: ZendeskConnector,
        DocumentSource.LOOPIO: LoopioConnector,
        DocumentSource.DROPBOX: DropboxConnector,
        DocumentSource.SHAREPOINT: SharepointConnector,
        DocumentSource.TEAMS: TeamsConnector,
        DocumentSource.SALESFORCE: SalesforceConnector,
        DocumentSource.DISCOURSE: DiscourseConnector,
        DocumentSource.AXERO: AxeroConnector,
        DocumentSource.CLICKUP: ClickupConnector,
        DocumentSource.MEDIAWIKI: MediaWikiConnector,
        DocumentSource.WIKIPEDIA: WikipediaConnector,
        DocumentSource.ASANA: AsanaConnector,
        DocumentSource.S3: BlobStorageConnector,
        DocumentSource.R2: BlobStorageConnector,
        DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
        DocumentSource.OCI_STORAGE: BlobStorageConnector,
        DocumentSource.XENFORO: XenforoConnector,
        DocumentSource.DISCORD: DiscordConnector,
        DocumentSource.FRESHDESK: FreshdeskConnector,
        DocumentSource.FIREFLIES: FirefliesConnector,
        DocumentSource.EGNYTE: EgnyteConnector,
        DocumentSource.AIRTABLE: AirtableConnector,
        DocumentSource.HIGHSPOT: HighspotConnector,
        DocumentSource.IMAP: ImapConnector,
        DocumentSource.BITBUCKET: BitbucketConnector,
        # just for integration tests
        DocumentSource.MOCK_CONNECTOR: MockConnector,
    }
    connector_by_source = connector_map.get(source, {})
    # Load the connector class using lazy loading
    connector = _load_connector_class(source)

    if isinstance(connector_by_source, dict):
        if input_type is None:
            # If not specified, default to most exhaustive update
            connector = connector_by_source.get(InputType.LOAD_STATE)
        else:
            connector = connector_by_source.get(input_type)
    else:
        connector = connector_by_source
    if connector is None:
        raise ConnectorMissingException(f"Connector not found for source={source}")
    # Validate connector supports the requested input_type
    _validate_connector_supports_input_type(connector, input_type, source)

    if any(
        [
            (
                input_type == InputType.LOAD_STATE
                and not issubclass(connector, LoadConnector)
            ),
            (
                input_type == InputType.POLL
                # either poll or checkpoint works for this, in the future
                # all connectors should be checkpoint connectors
                and (
                    not issubclass(connector, PollConnector)
                    and not issubclass(connector, CheckpointedConnector)
                )
            ),
            (
                input_type == InputType.EVENT
                and not issubclass(connector, EventConnector)
            ),
        ]
    ):
        raise ConnectorMissingException(
            f"Connector for source={source} does not accept input_type={input_type}"
        )
    return connector

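The hunk above consumes CONNECTOR_CLASS_MAP from onyx.connectors.registry but does not show that module; a minimal sketch of the entry shape implied by the module_path/class_name attribute access in _load_connector_class, with the container class name assumed rather than taken from the registry itself:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ConnectorMapping:  # hypothetical name; only the two fields are implied by usage
        module_path: str
        class_name: str

    CONNECTOR_CLASS_MAP = {
        DocumentSource.CONFLUENCE: ConnectorMapping(
            module_path="onyx.connectors.confluence.connector",
            class_name="ConfluenceConnector",
        ),
    }
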
@@ -219,12 +219,19 @@ def _get_batch_rate_limited(


def _get_userinfo(user: NamedUser) -> dict[str, str]:
    def _safe_get(attr_name: str) -> str | None:
        try:
            return cast(str | None, getattr(user, attr_name))
        except GithubException:
            logger.debug(f"Error getting {attr_name} for user")
            return None

    return {
        k: v
        for k, v in {
            "login": user.login,
            "name": user.name,
            "email": user.email,
            "login": _safe_get("login"),
            "name": _safe_get("name"),
            "email": _safe_get("email"),
        }.items()
        if v is not None
    }

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
@@ -232,7 +232,7 @@ def thread_to_document(
    )


class GmailConnector(LoadConnector, PollConnector, SlimConnector):
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size

@@ -397,10 +397,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
                if callback:
                    if callback.should_stop():
                        raise RuntimeError(
                            "retrieve_all_slim_documents: Stop signal detected"
                            "retrieve_all_slim_docs_perm_sync: Stop signal detected"
                        )

                    callback.progress("retrieve_all_slim_documents", 1)
                    callback.progress("retrieve_all_slim_docs_perm_sync", 1)
        except HttpError as e:
            if _is_mail_service_disabled_error(e):
                logger.warning(
@@ -431,7 +431,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
                raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
            raise e

    def retrieve_all_slim_documents(
    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,

@@ -64,7 +64,7 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -153,7 +153,7 @@ class DriveIdStatus(Enum):


class GoogleDriveConnector(
    SlimConnector, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
    SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
    def __init__(
        self,
@@ -1296,7 +1296,7 @@ class GoogleDriveConnector(
                callback.progress("_extract_slim_docs_from_google_drive", 1)
            yield slim_batch

    def retrieve_all_slim_documents(
    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -38,7 +38,7 @@ class HighspotSpot(BaseModel):
    name: str


class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """
    Connector for loading data from Highspot.

@@ -362,7 +362,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
        description = item_details.get("description", "")
        return title, description

    def retrieve_all_slim_documents(
    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,

@@ -1,14 +1,18 @@
import re
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar

import requests
from hubspot import HubSpot # type: ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -25,6 +29,10 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
# Available HubSpot object types
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}

HUBSPOT_PAGE_SIZE = 100

T = TypeVar("T")

logger = setup_logger()


@@ -38,6 +46,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
        self.batch_size = batch_size
        self._access_token = access_token
        self._portal_id: str | None = None
        self._rate_limiter = HubSpotRateLimiter()

        # Set object types to fetch, default to all available types
        if object_types is None:
@@ -77,6 +86,37 @@ class HubSpotConnector(LoadConnector, PollConnector):
        """Set the portal ID."""
        self._portal_id = value

    def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
        return self._rate_limiter.call(func, *args, **kwargs)

    def _paginated_results(
        self,
        fetch_page: Callable[..., Any],
        **kwargs: Any,
    ) -> Generator[Any, None, None]:
        base_kwargs = dict(kwargs)
        base_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)

        after: str | None = None
        while True:
            page_kwargs = base_kwargs.copy()
            if after is not None:
                page_kwargs["after"] = after

            page = self._call_hubspot(fetch_page, **page_kwargs)
            results = getattr(page, "results", [])
            for result in results:
                yield result

            paging = getattr(page, "paging", None)
            next_page = getattr(paging, "next", None) if paging else None
            if next_page is None:
                break

            after = getattr(next_page, "after", None)
            if after is None:
                break

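As a usage sketch of the new helper (the property names below are illustrative, not a required set), any CRM listing endpoint can be drained page by page through the rate limiter:

    # iterate all tickets, HUBSPOT_PAGE_SIZE at a time, following the `after` cursor
    for ticket in self._paginated_results(
        api_client.crm.tickets.basic_api.get_page,
        properties=["subject", "content"],
    ):
        logger.info(f"ticket id: {ticket.id}")
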
    def _clean_html_content(self, html_content: str) -> str:
        """Clean HTML content and extract raw text"""
        if not html_content:
@@ -150,78 +190,82 @@ class HubSpotConnector(LoadConnector, PollConnector):
    ) -> list[dict[str, Any]]:
        """Get associated objects for a given object"""
        try:
            associations = api_client.crm.associations.v4.basic_api.get_page(
            associations_iter = self._paginated_results(
                api_client.crm.associations.v4.basic_api.get_page,
                object_type=from_object_type,
                object_id=object_id,
                to_object_type=to_object_type,
            )

            associated_objects = []
            if associations.results:
                object_ids = [assoc.to_object_id for assoc in associations.results]
            object_ids = [assoc.to_object_id for assoc in associations_iter]

            # Batch get the associated objects
            if to_object_type == "contacts":
                for obj_id in object_ids:
                    try:
                        obj = api_client.crm.contacts.basic_api.get_by_id(
                            contact_id=obj_id,
                            properties=[
                                "firstname",
                                "lastname",
                                "email",
                                "company",
                                "jobtitle",
                            ],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch contact {obj_id}: {e}")
            associated_objects: list[dict[str, Any]] = []

            elif to_object_type == "companies":
                for obj_id in object_ids:
                    try:
                        obj = api_client.crm.companies.basic_api.get_by_id(
                            company_id=obj_id,
                            properties=[
                                "name",
                                "domain",
                                "industry",
                                "city",
                                "state",
                            ],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch company {obj_id}: {e}")
            if to_object_type == "contacts":
                for obj_id in object_ids:
                    try:
                        obj = self._call_hubspot(
                            api_client.crm.contacts.basic_api.get_by_id,
                            contact_id=obj_id,
                            properties=[
                                "firstname",
                                "lastname",
                                "email",
                                "company",
                                "jobtitle",
                            ],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch contact {obj_id}: {e}")

            elif to_object_type == "deals":
                for obj_id in object_ids:
                    try:
                        obj = api_client.crm.deals.basic_api.get_by_id(
                            deal_id=obj_id,
                            properties=[
                                "dealname",
                                "amount",
                                "dealstage",
                                "closedate",
                                "pipeline",
                            ],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch deal {obj_id}: {e}")
            elif to_object_type == "companies":
                for obj_id in object_ids:
                    try:
                        obj = self._call_hubspot(
                            api_client.crm.companies.basic_api.get_by_id,
                            company_id=obj_id,
                            properties=[
                                "name",
                                "domain",
                                "industry",
                                "city",
                                "state",
                            ],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch company {obj_id}: {e}")

            elif to_object_type == "tickets":
                for obj_id in object_ids:
                    try:
                        obj = api_client.crm.tickets.basic_api.get_by_id(
                            ticket_id=obj_id,
                            properties=["subject", "content", "hs_ticket_priority"],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
            elif to_object_type == "deals":
                for obj_id in object_ids:
                    try:
                        obj = self._call_hubspot(
                            api_client.crm.deals.basic_api.get_by_id,
                            deal_id=obj_id,
                            properties=[
                                "dealname",
                                "amount",
                                "dealstage",
                                "closedate",
                                "pipeline",
                            ],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch deal {obj_id}: {e}")

            elif to_object_type == "tickets":
                for obj_id in object_ids:
                    try:
                        obj = self._call_hubspot(
                            api_client.crm.tickets.basic_api.get_by_id,
                            ticket_id=obj_id,
                            properties=["subject", "content", "hs_ticket_priority"],
                        )
                        associated_objects.append(obj.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch ticket {obj_id}: {e}")

            return associated_objects

@@ -239,33 +283,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
    ) -> list[dict[str, Any]]:
        """Get notes associated with a given object"""
        try:
            # Get associations to notes (engagement type)
            associations = api_client.crm.associations.v4.basic_api.get_page(
            associations_iter = self._paginated_results(
                api_client.crm.associations.v4.basic_api.get_page,
                object_type=object_type,
                object_id=object_id,
                to_object_type="notes",
            )

            associated_notes = []
            if associations.results:
                note_ids = [assoc.to_object_id for assoc in associations.results]
            note_ids = [assoc.to_object_id for assoc in associations_iter]

            # Batch get the associated notes
                for note_id in note_ids:
                    try:
                        # Notes are engagements in HubSpot, use the engagements API
                        note = api_client.crm.objects.notes.basic_api.get_by_id(
                            note_id=note_id,
                            properties=[
                                "hs_note_body",
                                "hs_timestamp",
                                "hs_created_by",
                                "hubspot_owner_id",
                            ],
                        )
                        associated_notes.append(note.to_dict())
                    except Exception as e:
                        logger.warning(f"Failed to fetch note {note_id}: {e}")
            associated_notes = []

            for note_id in note_ids:
                try:
                    # Notes are engagements in HubSpot, use the engagements API
                    note = self._call_hubspot(
                        api_client.crm.objects.notes.basic_api.get_by_id,
                        note_id=note_id,
                        properties=[
                            "hs_note_body",
                            "hs_timestamp",
                            "hs_created_by",
                            "hubspot_owner_id",
                        ],
                    )
                    associated_notes.append(note.to_dict())
                except Exception as e:
                    logger.warning(f"Failed to fetch note {note_id}: {e}")

            return associated_notes

@@ -358,7 +402,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        api_client = HubSpot(access_token=self.access_token)
        all_tickets = api_client.crm.tickets.get_all(

        tickets_iter = self._paginated_results(
            api_client.crm.tickets.basic_api.get_page,
            properties=[
                "subject",
                "content",
@@ -371,7 +417,7 @@ class HubSpotConnector(LoadConnector, PollConnector):

        doc_batch: list[Document] = []

        for ticket in all_tickets:
        for ticket in tickets_iter:
            updated_at = ticket.updated_at.replace(tzinfo=None)
            if start is not None and updated_at < start.replace(tzinfo=None):
                continue
@@ -459,7 +505,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        api_client = HubSpot(access_token=self.access_token)
        all_companies = api_client.crm.companies.get_all(

        companies_iter = self._paginated_results(
            api_client.crm.companies.basic_api.get_page,
            properties=[
                "name",
                "domain",
@@ -475,7 +523,7 @@ class HubSpotConnector(LoadConnector, PollConnector):

        doc_batch: list[Document] = []

        for company in all_companies:
        for company in companies_iter:
            updated_at = company.updated_at.replace(tzinfo=None)
            if start is not None and updated_at < start.replace(tzinfo=None):
                continue
@@ -582,7 +630,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        api_client = HubSpot(access_token=self.access_token)
        all_deals = api_client.crm.deals.get_all(

        deals_iter = self._paginated_results(
            api_client.crm.deals.basic_api.get_page,
            properties=[
                "dealname",
                "amount",
@@ -598,7 +648,7 @@ class HubSpotConnector(LoadConnector, PollConnector):

        doc_batch: list[Document] = []

        for deal in all_deals:
        for deal in deals_iter:
            updated_at = deal.updated_at.replace(tzinfo=None)
            if start is not None and updated_at < start.replace(tzinfo=None):
                continue
@@ -703,7 +753,9 @@ class HubSpotConnector(LoadConnector, PollConnector):
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        api_client = HubSpot(access_token=self.access_token)
        all_contacts = api_client.crm.contacts.get_all(

        contacts_iter = self._paginated_results(
            api_client.crm.contacts.basic_api.get_page,
            properties=[
                "firstname",
                "lastname",
@@ -721,7 +773,7 @@ class HubSpotConnector(LoadConnector, PollConnector):

        doc_batch: list[Document] = []

        for contact in all_contacts:
        for contact in contacts_iter:
            updated_at = contact.updated_at.replace(tzinfo=None)
            if start is not None and updated_at < start.replace(tzinfo=None):
                continue

Some files were not shown because too many files have changed in this diff.