Compare commits

..

1 Commit

Author: Weves
SHA1: 18f13bd2ed
Message: Fix pre-commit to exclude .venv directory from lazy import check
Date: 2025-09-25 12:59:49 -07:00
666 changed files with 29451 additions and 30611 deletions

View File

@@ -8,9 +8,9 @@ on:
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
jobs:
build-and-push:
@@ -33,16 +33,7 @@ jobs:
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout code
uses: actions/checkout@v4
@@ -55,8 +46,7 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -129,8 +119,7 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3
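
To make the tagging change above easier to follow, here is a rough Python sketch of how the new `EDGE_TAG` / `LATEST_TAG` expressions resolve for a given ref name, replacing the removed "Check if stable release version" step. This is illustration only; the real logic is the GitHub Actions expressions in the diff, and the example ref names are made up.

```python
# Sketch of the backend-image tag selection implied by the new env vars.
# Assumption: ref names like "nightly-latest-20250925" are illustrative.
def resolve_backend_tags(ref_name: str) -> list[str]:
    edge = ref_name.startswith("nightly-latest")               # EDGE_TAG
    latest = "latest" in ref_name and "cloud" not in ref_name  # LATEST_TAG
    tags = [ref_name]
    if edge:
        tags.append("edge")
    if latest:
        tags.append("latest")
    return tags

# resolve_backend_tags("nightly-latest-20250925")
#   -> ["nightly-latest-20250925", "edge", "latest"]
# resolve_backend_tags("v1.2.3-cloud")
#   -> ["v1.2.3-cloud"]   (cloud images are never tagged "latest")
```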

View File

@@ -11,8 +11,8 @@ env:
BUILDKIT_PROGRESS: plain
DEPLOYMENT: ${{ contains(github.ref_name, 'cloud') && 'cloud' || 'standalone' }}
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
# don't tag cloud images with "latest"
LATEST_TAG: ${{ contains(github.ref_name, 'latest') && !contains(github.ref_name, 'cloud') }}
jobs:
@@ -145,15 +145,6 @@ jobs:
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
@@ -166,16 +157,11 @@ jobs:
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
if [[ "${{ steps.check_version.outputs.is_stable }}" == "true" ]]; then
if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
if [[ "${{ env.EDGE_TAG }}" == "true" ]]; then
docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:edge \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
fi
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@v3

View File

@@ -7,10 +7,7 @@ on:
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
# tag nightly builds with "edge"
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
DEPLOYMENT: standalone
jobs:
@@ -48,15 +45,6 @@ jobs:
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Checkout
uses: actions/checkout@v4
@@ -69,8 +57,7 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -139,8 +126,7 @@ jobs:
latest=false
tags: |
type=raw,value=${{ github.ref_name }}
type=raw,value=${{ steps.check_version.outputs.is_stable == 'true' && 'latest' || '' }}
type=raw,value=${{ env.EDGE_TAG == 'true' && 'edge' || '' }}
type=raw,value=${{ env.LATEST_TAG == 'true' && 'latest' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@v3

View File

@@ -25,11 +25,9 @@ jobs:
- name: Add required Helm repositories
run: |
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add keda https://kedacore.github.io/charts
helm repo update
- name: Build chart dependencies

View File

@@ -20,7 +20,6 @@ env:
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# LLMs
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -80,7 +79,7 @@ jobs:
- name: Set up Standard Dependencies
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d minio relational_db cache index
docker compose up -d minio relational_db cache index
- name: Wait for services
run: |

View File

@@ -65,45 +65,35 @@ jobs:
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo update
- name: Install Redis operator
if: steps.list-changed.outputs.changed == 'true'
shell: bash
run: |
echo "=== Installing redis-operator CRDs ==="
helm upgrade --install redis-operator ot-container-kit/redis-operator \
--namespace redis-operator --create-namespace --wait --timeout 300s
- name: Pre-pull required images
- name: Pre-pull critical images
if: steps.list-changed.outputs.changed == 'true'
run: |
echo "=== Pre-pulling required images to avoid timeout ==="
echo "=== Pre-pulling critical images to avoid timeout ==="
# Get kind cluster name
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
echo "Kind cluster: $KIND_CLUSTER"
IMAGES=(
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
"quay.io/opstree/redis:v7.0.15"
"docker.io/onyxdotapp/onyx-web-server:latest"
)
for image in "${IMAGES[@]}"; do
echo "Pre-pulling $image"
if docker pull "$image"; then
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
else
echo "Failed to pull $image"
fi
done
# Pre-pull images that are likely to be used
echo "Pre-pulling PostgreSQL image..."
docker pull postgres:15-alpine || echo "Failed to pull postgres:15-alpine"
kind load docker-image postgres:15-alpine --name $KIND_CLUSTER || echo "Failed to load postgres image"
echo "Pre-pulling Redis image..."
docker pull redis:7-alpine || echo "Failed to pull redis:7-alpine"
kind load docker-image redis:7-alpine --name $KIND_CLUSTER || echo "Failed to load redis image"
echo "Pre-pulling Onyx images..."
docker pull docker.io/onyxdotapp/onyx-web-server:latest || echo "Failed to pull onyx web server"
docker pull docker.io/onyxdotapp/onyx-backend:latest || echo "Failed to pull onyx backend"
kind load docker-image docker.io/onyxdotapp/onyx-web-server:latest --name $KIND_CLUSTER || echo "Failed to load onyx web server"
kind load docker-image docker.io/onyxdotapp/onyx-backend:latest --name $KIND_CLUSTER || echo "Failed to load onyx backend"
echo "=== Images loaded into Kind cluster ==="
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
docker exec $KIND_CLUSTER-control-plane crictl images | grep -E "(postgres|redis|onyx)" || echo "Some images may still be loading..."
- name: Validate chart dependencies
if: steps.list-changed.outputs.changed == 'true'
@@ -159,7 +149,6 @@ jobs:
# Run the actual installation with detailed logging
echo "=== Starting ct install ==="
set +e
ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
@@ -167,10 +156,8 @@ jobs:
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \
--set=postgresql.cluster.storage.storageClass=standard \
--set=postgresql.primary.persistence.enabled=false \
--set=redis.enabled=true \
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
--set=webserver.replicaCount=1 \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
@@ -182,20 +169,11 @@ jobs:
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_file_processing.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--helm-extra-args="--timeout 900s --debug" \
--debug --config ct.yaml
CT_EXIT=$?
set -e
if [[ $CT_EXIT -ne 0 ]]; then
echo "ct install failed with exit code $CT_EXIT"
exit $CT_EXIT
else
echo "=== Installation completed successfully ==="
fi
echo "=== Installation completed successfully ==="
kubectl get pods --all-namespaces
- name: Post-install verification
@@ -220,7 +198,7 @@ jobs:
echo "=== Recent logs for debugging ==="
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
echo "=== Helm releases ==="
helm list --all-namespaces
# the following would install only changed charts, but we only have one chart so

View File

@@ -22,11 +22,9 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -264,7 +262,7 @@ jobs:
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
docker compose up \
relational_db \
index \
cache \
@@ -342,11 +340,9 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -19,11 +19,9 @@ env:
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
@@ -260,7 +258,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
docker compose up \
relational_db \
index \
cache \
@@ -339,11 +337,9 @@ jobs:
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \

View File

@@ -56,8 +56,6 @@ jobs:
provenance: false
sbom: false
push: true
outputs: type=registry
# no-cache: true
build-backend-image:
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
@@ -89,8 +87,6 @@ jobs:
provenance: false
sbom: false
push: true
outputs: type=registry
# no-cache: true
build-model-server-image:
runs-on: blacksmith-8vcpu-ubuntu-2404-arm
@@ -122,8 +118,6 @@ jobs:
provenance: false
sbom: false
push: true
outputs: type=registry
# no-cache: true
playwright-tests:
needs: [build-web-image, build-backend-image, build-model-server-image]
@@ -185,22 +179,17 @@ jobs:
working-directory: ./web
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
AUTH_TYPE=basic
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }}
EXA_API_KEY=${{ env.EXA_API_KEY }}
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
IMAGE_TAG=test
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
GEN_AI_API_KEY=${{ env.OPENAI_API_KEY }} \
EXA_API_KEY=${{ env.EXA_API_KEY }} \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose up -d
id: start_docker
- name: Wait for service to be ready
@@ -239,16 +228,14 @@ jobs:
- name: Run Playwright tests
working-directory: ./web
run: |
# Create test-results directory to ensure it exists for artifact upload
mkdir -p test-results
npx playwright test
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
with:
# Includes test results and debug screenshots
name: playwright-test-results-${{ github.run_id }}
# Chromatic automatically defaults to the test-results directory.
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
name: test-results
path: ./web/test-results
retention-days: 30

View File

@@ -20,13 +20,11 @@ env:
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
# Jira
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
# Gong
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
@@ -98,13 +96,6 @@ env:
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
TEAMS_SECRET: ${{ secrets.TEAMS_SECRET }}
# Bitbucket
BITBUCKET_WORKSPACE: ${{ secrets.BITBUCKET_WORKSPACE }}
BITBUCKET_REPOSITORIES: ${{ secrets.BITBUCKET_REPOSITORIES }}
BITBUCKET_PROJECTS: ${{ secrets.BITBUCKET_PROJECTS }}
BITBUCKET_EMAIL: ${{ secrets.BITBUCKET_EMAIL }}
BITBUCKET_API_TOKEN: ${{ secrets.BITBUCKET_API_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -43,7 +43,7 @@ repos:
name: Check lazy imports are not directly imported
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/.*\.py$
files: ^backend/(?!\.venv/).*\.py$
pass_filenames: false
# We would like to have a mypy pre-commit hook, but due to the fact that
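
For context on the commit's pre-commit fix, a minimal sketch of how the new `files` pattern behaves: pre-commit matches `files` as a Python regex against each path, and the negative lookahead `(?!\.venv/)` now skips anything under `backend/.venv/`. The `.venv` path below is illustrative, not taken from the diff.

```python
import re

# New pattern from the pre-commit hook above.
NEW_PATTERN = re.compile(r"^backend/(?!\.venv/).*\.py$")

assert NEW_PATTERN.match("backend/scripts/check_lazy_imports.py")  # still checked
assert not NEW_PATTERN.match("backend/.venv/lib/example.py")       # now excluded
```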

View File

@@ -1,468 +1,444 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": ["--- Individual ---"],
"presentation": {
"group": "1"
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": ["--- Individual ---"],
"presentation": {
"group": "1"
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
},
"stopAll": true
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": ["run", "dev"],
"presentation": {
"group": "2"
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": ["model_server.main:app", "--reload", "--port", "9000"],
"presentation": {
"group": "2"
},
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_ONYX_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": ["onyx.main:app", "--reload", "--port", "8080"],
"presentation": {
"group": "2"
},
"consoleTitle": "API Server Console"
},
// For the listener to access the Slack API,
// ONYX_BOT_SLACK_APP_TOKEN & ONYX_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery docfetching",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
],
"presentation": {
"group": "1"
}
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_containers.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/backend/onyx/evals/eval_cli.py",
"cwd": "${workspaceFolder}/backend",
"console": "integratedTerminal",
"justMyCode": false,
"envFile": "${workspaceFolder}/.vscode/.env",
"presentation": {
"group": "3"
},
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"--verbose"
],
"consoleTitle": "Eval CLI Console"
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
],
"presentation": {
"group": "1"
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_ONYX_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
"stopAll": true
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": ["run", "dev"],
"presentation": {
"group": "2"
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": ["model_server.main:app", "--reload", "--port", "9000"],
"presentation": {
"group": "2"
},
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_ONYX_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": ["onyx.main:app", "--reload", "--port", "8080"],
"presentation": {
"group": "2"
},
"consoleTitle": "API Server Console"
},
// For the listener to access the Slack API,
// ONYX_BOT_SLACK_APP_TOKEN & ONYX_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery docfetching",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user file processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"--pool=threads",
"-Q",
"user_file_processing,user_file_project_sync"
],
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user file processing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a specific module/test to run or provide nothing to run all tests
// "tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_containers.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/backend/onyx/evals/eval_cli.py",
"cwd": "${workspaceFolder}/backend",
"console": "integratedTerminal",
"justMyCode": false,
"envFile": "${workspaceFolder}/.vscode/.env",
"presentation": {
"group": "3"
},
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": ["--verbose"],
"consoleTitle": "Eval CLI Console"
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_ONYX_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
// script to generate the openapi schema
"name": "Onyx OpenAPI Schema Generator",
@@ -475,7 +451,10 @@
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": ["--filename", "generated/openapi.json"]
"args": [
"--filename",
"generated/openapi.json"
]
},
{
// script to debug multi tenant db issues
@@ -500,12 +479,13 @@
"generated/tenants_by_num_docs.csv"
]
},
{
"name": "Debug React Web App in Chrome",
"type": "chrome",
"request": "launch",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/web"
}
]
}
{
"name": "Debug React Web App in Chrome",
"type": "chrome",
"request": "launch",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/web"
}
]
}

View File

@@ -105,11 +105,6 @@ pip install -r backend/requirements/ee.txt
pip install -r backend/requirements/model_server.txt
```
Fix vscode/cursor auto-imports:
```bash
pip install -e .
```
Install Playwright for Python (headless browser required by the Web Connector)
In the activated Python virtualenv, install Playwright for Python by running:

View File

@@ -1,389 +0,0 @@
"""Migration 2: User file data preparation and backfill
Revision ID: 0cd424f32b1d
Revises: 9b66d3156fc6
Create Date: 2025-09-22 09:44:42.727034
This migration populates the new columns added in migration 1.
It prepares data for the UUID transition and relationship migration.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "0cd424f32b1d"
down_revision = "9b66d3156fc6"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Populate new columns with data."""
bind = op.get_bind()
inspector = sa.inspect(bind)
# === Step 1: Populate user_file.new_id ===
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
has_new_id = "new_id" in user_file_columns
if has_new_id:
logger.info("Populating user_file.new_id with UUIDs...")
# Count rows needing UUIDs
null_count = bind.execute(
text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
).scalar_one()
if null_count > 0:
logger.info(f"Generating UUIDs for {null_count} user_file records...")
# Populate in batches to avoid long locks
batch_size = 10000
total_updated = 0
while True:
result = bind.execute(
text(
"""
UPDATE user_file
SET new_id = gen_random_uuid()
WHERE new_id IS NULL
AND id IN (
SELECT id FROM user_file
WHERE new_id IS NULL
LIMIT :batch_size
)
"""
),
{"batch_size": batch_size},
)
updated = result.rowcount
total_updated += updated
if updated < batch_size:
break
logger.info(f" Updated {total_updated}/{null_count} records...")
logger.info(f"Generated UUIDs for {total_updated} user_file records")
# Verify all records have UUIDs
remaining_null = bind.execute(
text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL")
).scalar_one()
if remaining_null > 0:
raise Exception(
f"Failed to populate all user_file.new_id values ({remaining_null} NULL)"
)
# Lock down the column
op.alter_column("user_file", "new_id", nullable=False)
op.alter_column("user_file", "new_id", server_default=None)
logger.info("Locked down user_file.new_id column")
# === Step 2: Populate persona__user_file.user_file_id_uuid ===
persona_user_file_columns = [
col["name"] for col in inspector.get_columns("persona__user_file")
]
if has_new_id and "user_file_id_uuid" in persona_user_file_columns:
logger.info("Populating persona__user_file.user_file_id_uuid...")
# Count rows needing update
null_count = bind.execute(
text(
"""
SELECT COUNT(*) FROM persona__user_file
WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
"""
)
).scalar_one()
if null_count > 0:
logger.info(f"Updating {null_count} persona__user_file records...")
# Update in batches
batch_size = 10000
total_updated = 0
while True:
result = bind.execute(
text(
"""
UPDATE persona__user_file p
SET user_file_id_uuid = uf.new_id
FROM user_file uf
WHERE p.user_file_id = uf.id
AND p.user_file_id_uuid IS NULL
AND p.persona_id IN (
SELECT persona_id
FROM persona__user_file
WHERE user_file_id_uuid IS NULL
LIMIT :batch_size
)
"""
),
{"batch_size": batch_size},
)
updated = result.rowcount
total_updated += updated
if updated < batch_size:
break
logger.info(f" Updated {total_updated}/{null_count} records...")
logger.info(f"Updated {total_updated} persona__user_file records")
# Verify all records are populated
remaining_null = bind.execute(
text(
"""
SELECT COUNT(*) FROM persona__user_file
WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL
"""
)
).scalar_one()
if remaining_null > 0:
raise Exception(
f"Failed to populate all persona__user_file.user_file_id_uuid values ({remaining_null} NULL)"
)
op.alter_column("persona__user_file", "user_file_id_uuid", nullable=False)
logger.info("Locked down persona__user_file.user_file_id_uuid column")
# === Step 3: Create user_project records from chat_folder ===
if "chat_folder" in inspector.get_table_names():
logger.info("Creating user_project records from chat_folder...")
result = bind.execute(
text(
"""
INSERT INTO user_project (user_id, name)
SELECT cf.user_id, cf.name
FROM chat_folder cf
WHERE NOT EXISTS (
SELECT 1
FROM user_project up
WHERE up.user_id = cf.user_id AND up.name = cf.name
)
"""
)
)
logger.info(f"Created {result.rowcount} user_project records from chat_folder")
# === Step 4: Populate chat_session.project_id ===
chat_session_columns = [
col["name"] for col in inspector.get_columns("chat_session")
]
if "folder_id" in chat_session_columns and "project_id" in chat_session_columns:
logger.info("Populating chat_session.project_id...")
# Count sessions needing update
null_count = bind.execute(
text(
"""
SELECT COUNT(*) FROM chat_session
WHERE project_id IS NULL AND folder_id IS NOT NULL
"""
)
).scalar_one()
if null_count > 0:
logger.info(f"Updating {null_count} chat_session records...")
result = bind.execute(
text(
"""
UPDATE chat_session cs
SET project_id = up.id
FROM chat_folder cf
JOIN user_project up ON up.user_id = cf.user_id AND up.name = cf.name
WHERE cs.folder_id = cf.id AND cs.project_id IS NULL
"""
)
)
logger.info(f"Updated {result.rowcount} chat_session records")
# Verify all records are populated
remaining_null = bind.execute(
text(
"""
SELECT COUNT(*) FROM chat_session
WHERE project_id IS NULL AND folder_id IS NOT NULL
"""
)
).scalar_one()
if remaining_null > 0:
logger.warning(
f"Warning: {remaining_null} chat_session records could not be mapped to projects"
)
# === Step 5: Update plaintext FileRecord IDs/display names to UUID scheme ===
# Prior to UUID migration, plaintext cache files were stored with file_id like 'plain_text_<int_id>'.
# After migration, we use 'plaintext_<uuid>' (note the name change to 'plaintext_').
# This step remaps existing FileRecord rows to the new naming while preserving object_key/bucket.
logger.info("Updating plaintext FileRecord ids and display names to UUID scheme...")
# Count legacy plaintext records that can be mapped to UUID user_file ids
count_query = text(
"""
SELECT COUNT(*)
FROM file_record fr
JOIN user_file uf ON fr.file_id = CONCAT('plaintext_', uf.id::text)
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
"""
)
legacy_count = bind.execute(count_query).scalar_one()
if legacy_count and legacy_count > 0:
logger.info(f"Found {legacy_count} legacy plaintext file records to update")
# Update display_name first for readability (safe regardless of rename)
bind.execute(
text(
"""
UPDATE file_record fr
SET display_name = CONCAT('Plaintext for user file ', uf.new_id::text)
FROM user_file uf
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
AND fr.file_id = CONCAT('plaintext_', uf.id::text)
"""
)
)
# Remap file_id from 'plaintext_<int>' -> 'plaintext_<uuid>' using transitional new_id
# Use a single UPDATE ... WHERE file_id LIKE 'plain_text_%'
# and ensure it aligns to existing user_file ids to avoid renaming unrelated rows
result = bind.execute(
text(
"""
UPDATE file_record fr
SET file_id = CONCAT('plaintext_', uf.new_id::text)
FROM user_file uf
WHERE LOWER(fr.file_origin::text) = 'plaintext_cache'
AND fr.file_id = CONCAT('plaintext_', uf.id::text)
"""
)
)
logger.info(
f"Updated {result.rowcount} plaintext file_record ids to UUID scheme"
)
# === Step 6: Ensure document_id_migrated default TRUE and backfill existing FALSE ===
# New records should default to migrated=True so the migration task won't run for them.
# Existing rows that had a legacy document_id should be marked as not migrated to be processed.
# Backfill existing records: if document_id is not null, set to FALSE
bind.execute(
text(
"""
UPDATE user_file
SET document_id_migrated = FALSE
WHERE document_id IS NOT NULL
"""
)
)
# === Step 7: Backfill user_file.status from index_attempt ===
logger.info("Backfilling user_file.status from index_attempt...")
# Update user_file status based on latest index attempt
# Using CTEs instead of temp tables for asyncpg compatibility
result = bind.execute(
text(
"""
WITH latest_attempt AS (
SELECT DISTINCT ON (ia.connector_credential_pair_id)
ia.connector_credential_pair_id,
ia.status
FROM index_attempt ia
ORDER BY ia.connector_credential_pair_id, ia.time_updated DESC
),
uf_to_ccp AS (
SELECT DISTINCT uf.id AS uf_id, ccp.id AS cc_pair_id
FROM user_file uf
JOIN document_by_connector_credential_pair dcc
ON dcc.id = REPLACE(uf.document_id, 'USER_FILE_CONNECTOR__', 'FILE_CONNECTOR__')
JOIN connector_credential_pair ccp
ON ccp.connector_id = dcc.connector_id
AND ccp.credential_id = dcc.credential_id
)
UPDATE user_file uf
SET status = CASE
WHEN la.status IN ('NOT_STARTED', 'IN_PROGRESS') THEN 'PROCESSING'
WHEN la.status = 'SUCCESS' THEN 'COMPLETED'
ELSE 'FAILED'
END
FROM uf_to_ccp ufc
LEFT JOIN latest_attempt la
ON la.connector_credential_pair_id = ufc.cc_pair_id
WHERE uf.id = ufc.uf_id
AND uf.status = 'PROCESSING'
"""
)
)
logger.info(f"Updated status for {result.rowcount} user_file records")
logger.info("Migration 2 (data preparation) completed successfully")
def downgrade() -> None:
"""Reset populated data to allow clean downgrade of schema."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting downgrade of data preparation...")
# Reset user_file columns to allow nulls before data removal
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
if "new_id" in columns:
op.alter_column(
"user_file",
"new_id",
nullable=True,
server_default=sa.text("gen_random_uuid()"),
)
# Optionally clear the data
# bind.execute(text("UPDATE user_file SET new_id = NULL"))
logger.info("Reset user_file.new_id to nullable")
# Reset persona__user_file.user_file_id_uuid
if "persona__user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
if "user_file_id_uuid" in columns:
op.alter_column("persona__user_file", "user_file_id_uuid", nullable=True)
# Optionally clear the data
# bind.execute(text("UPDATE persona__user_file SET user_file_id_uuid = NULL"))
logger.info("Reset persona__user_file.user_file_id_uuid to nullable")
# Note: We don't delete user_project records or reset chat_session.project_id
# as these might be in use and can be handled by the schema downgrade
# Reset user_file.status to default
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
if "status" in columns:
bind.execute(text("UPDATE user_file SET status = 'PROCESSING'"))
logger.info("Reset user_file.status to default")
logger.info("Downgrade completed successfully")

View File

@@ -1,261 +0,0 @@
"""Migration 3: User file relationship migration
Revision ID: 16c37a30adf2
Revises: 0cd424f32b1d
Create Date: 2025-09-22 09:47:34.175596
This migration converts folder-based relationships to project-based relationships.
It migrates persona__user_folder to persona__user_file and populates project__user_file.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "16c37a30adf2"
down_revision = "0cd424f32b1d"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Migrate folder-based relationships to project-based relationships."""
bind = op.get_bind()
inspector = sa.inspect(bind)
# === Step 1: Migrate persona__user_folder to persona__user_file ===
table_names = inspector.get_table_names()
if "persona__user_folder" in table_names and "user_file" in table_names:
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
has_new_id = "new_id" in user_file_columns
if has_new_id and "folder_id" in user_file_columns:
logger.info(
"Migrating persona__user_folder relationships to persona__user_file..."
)
# Count relationships to migrate (asyncpg-compatible)
count_query = text(
"""
SELECT COUNT(*)
FROM (
SELECT DISTINCT puf.persona_id, uf.id
FROM persona__user_folder puf
JOIN user_file uf ON uf.folder_id = puf.user_folder_id
WHERE NOT EXISTS (
SELECT 1
FROM persona__user_file p2
WHERE p2.persona_id = puf.persona_id
AND p2.user_file_id = uf.id
)
) AS distinct_pairs
"""
)
to_migrate = bind.execute(count_query).scalar_one()
if to_migrate > 0:
logger.info(f"Creating {to_migrate} persona-file relationships...")
# Migrate in batches to avoid memory issues
batch_size = 10000
total_inserted = 0
while True:
# Insert batch directly using subquery (asyncpg compatible)
result = bind.execute(
text(
"""
INSERT INTO persona__user_file (persona_id, user_file_id, user_file_id_uuid)
SELECT DISTINCT puf.persona_id, uf.id as file_id, uf.new_id
FROM persona__user_folder puf
JOIN user_file uf ON uf.folder_id = puf.user_folder_id
WHERE NOT EXISTS (
SELECT 1
FROM persona__user_file p2
WHERE p2.persona_id = puf.persona_id
AND p2.user_file_id = uf.id
)
LIMIT :batch_size
"""
),
{"batch_size": batch_size},
)
inserted = result.rowcount
total_inserted += inserted
if inserted < batch_size:
break
logger.info(
f" Migrated {total_inserted}/{to_migrate} relationships..."
)
logger.info(
f"Created {total_inserted} persona__user_file relationships"
)
# === Step 2: Add foreign key for chat_session.project_id ===
chat_session_fks = inspector.get_foreign_keys("chat_session")
fk_exists = any(
fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
)
if not fk_exists:
logger.info("Adding foreign key constraint for chat_session.project_id...")
op.create_foreign_key(
"fk_chat_session_project_id",
"chat_session",
"user_project",
["project_id"],
["id"],
)
logger.info("Added foreign key constraint")
# === Step 3: Populate project__user_file from user_file.folder_id ===
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
has_new_id = "new_id" in user_file_columns
if has_new_id and "folder_id" in user_file_columns:
logger.info("Populating project__user_file from folder relationships...")
# Count relationships to create
count_query = text(
"""
SELECT COUNT(*)
FROM user_file uf
WHERE uf.folder_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM project__user_file puf
WHERE puf.project_id = uf.folder_id
AND puf.user_file_id = uf.new_id
)
"""
)
to_create = bind.execute(count_query).scalar_one()
if to_create > 0:
logger.info(f"Creating {to_create} project-file relationships...")
# Insert in batches
batch_size = 10000
total_inserted = 0
while True:
result = bind.execute(
text(
"""
INSERT INTO project__user_file (project_id, user_file_id)
SELECT uf.folder_id, uf.new_id
FROM user_file uf
WHERE uf.folder_id IS NOT NULL
AND NOT EXISTS (
SELECT 1
FROM project__user_file puf
WHERE puf.project_id = uf.folder_id
AND puf.user_file_id = uf.new_id
)
LIMIT :batch_size
ON CONFLICT (project_id, user_file_id) DO NOTHING
"""
),
{"batch_size": batch_size},
)
inserted = result.rowcount
total_inserted += inserted
if inserted < batch_size:
break
logger.info(f" Created {total_inserted}/{to_create} relationships...")
logger.info(f"Created {total_inserted} project__user_file relationships")
# === Step 4: Create index on chat_session.project_id ===
try:
indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
except Exception:
indexes = []
if "ix_chat_session_project_id" not in indexes:
logger.info("Creating index on chat_session.project_id...")
op.create_index(
"ix_chat_session_project_id", "chat_session", ["project_id"], unique=False
)
logger.info("Created index")
logger.info("Migration 3 (relationship migration) completed successfully")
def downgrade() -> None:
"""Remove migrated relationships and constraints."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting downgrade of relationship migration...")
# Drop index on chat_session.project_id
try:
indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")]
if "ix_chat_session_project_id" in indexes:
op.drop_index("ix_chat_session_project_id", "chat_session")
logger.info("Dropped index on chat_session.project_id")
except Exception:
pass
# Drop foreign key constraint
try:
chat_session_fks = inspector.get_foreign_keys("chat_session")
fk_exists = any(
fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks
)
if fk_exists:
op.drop_constraint(
"fk_chat_session_project_id", "chat_session", type_="foreignkey"
)
logger.info("Dropped foreign key constraint on chat_session.project_id")
except Exception:
pass
# Clear project__user_file relationships (but keep the table for migration 1 to handle)
if "project__user_file" in inspector.get_table_names():
result = bind.execute(text("DELETE FROM project__user_file"))
logger.info(f"Cleared {result.rowcount} records from project__user_file")
# Remove migrated persona__user_file relationships
# Only remove those that came from folder relationships
if all(
table in inspector.get_table_names()
for table in ["persona__user_file", "persona__user_folder", "user_file"]
):
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
if "folder_id" in user_file_columns:
result = bind.execute(
text(
"""
DELETE FROM persona__user_file puf
WHERE EXISTS (
SELECT 1
FROM user_file uf
JOIN persona__user_folder puf2
ON puf2.user_folder_id = uf.folder_id
WHERE puf.persona_id = puf2.persona_id
AND puf.user_file_id = uf.id
)
"""
)
)
logger.info(
f"Removed {result.rowcount} migrated persona__user_file relationships"
)
logger.info("Downgrade completed successfully")

View File

@@ -1,218 +0,0 @@
"""Migration 6: User file schema cleanup
Revision ID: 2b75d0a8ffcb
Revises: 3a78dba1080a
Create Date: 2025-09-22 10:09:26.375377
This migration removes legacy columns and tables after data migration is complete.
It should only be run after verifying all data has been successfully migrated.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "2b75d0a8ffcb"
down_revision = "3a78dba1080a"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Remove legacy columns and tables."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting schema cleanup...")
# === Step 1: Verify data migration is complete ===
logger.info("Verifying data migration completion...")
# Check if any chat sessions still have folder_id references
chat_session_columns = [
col["name"] for col in inspector.get_columns("chat_session")
]
if "folder_id" in chat_session_columns:
orphaned_count = bind.execute(
text(
"""
SELECT COUNT(*) FROM chat_session
WHERE folder_id IS NOT NULL AND project_id IS NULL
"""
)
).scalar_one()
if orphaned_count > 0:
logger.warning(
f"WARNING: {orphaned_count} chat_session records still have "
f"folder_id without project_id. Proceeding anyway."
)
# === Step 2: Drop chat_session.folder_id ===
if "folder_id" in chat_session_columns:
logger.info("Dropping chat_session.folder_id...")
# Drop foreign key constraint first
op.execute(
"ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_folder_fk"
)
# Drop the column
op.drop_column("chat_session", "folder_id")
logger.info("Dropped chat_session.folder_id")
# === Step 3: Drop persona__user_folder table ===
if "persona__user_folder" in inspector.get_table_names():
logger.info("Dropping persona__user_folder table...")
# Check for any remaining data
remaining = bind.execute(
text("SELECT COUNT(*) FROM persona__user_folder")
).scalar_one()
if remaining > 0:
logger.warning(
f"WARNING: Dropping persona__user_folder with {remaining} records"
)
op.drop_table("persona__user_folder")
logger.info("Dropped persona__user_folder table")
# === Step 4: Drop chat_folder table ===
if "chat_folder" in inspector.get_table_names():
logger.info("Dropping chat_folder table...")
# Check for any remaining data
remaining = bind.execute(text("SELECT COUNT(*) FROM chat_folder")).scalar_one()
if remaining > 0:
logger.warning(f"WARNING: Dropping chat_folder with {remaining} records")
op.drop_table("chat_folder")
logger.info("Dropped chat_folder table")
# === Step 5: Drop user_file legacy columns ===
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
# Drop folder_id
if "folder_id" in user_file_columns:
logger.info("Dropping user_file.folder_id...")
op.drop_column("user_file", "folder_id")
logger.info("Dropped user_file.folder_id")
    # Drop cc_pair_id (should already be handled in migration 5, but verify to be safe)
if "cc_pair_id" in user_file_columns:
logger.info("Dropping user_file.cc_pair_id...")
# Drop any remaining foreign key constraints
bind.execute(
text(
"""
DO $$
DECLARE r RECORD;
BEGIN
FOR r IN (
SELECT conname
FROM pg_constraint c
JOIN pg_class t ON c.conrelid = t.oid
WHERE c.contype = 'f'
AND t.relname = 'user_file'
AND EXISTS (
SELECT 1 FROM pg_attribute a
WHERE a.attrelid = t.oid
AND a.attname = 'cc_pair_id'
)
) LOOP
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
END LOOP;
END$$;
"""
)
)
op.drop_column("user_file", "cc_pair_id")
logger.info("Dropped user_file.cc_pair_id")
# === Step 6: Clean up any remaining constraints ===
logger.info("Cleaning up remaining constraints...")
# Drop any unique constraints on removed columns
op.execute(
"ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_cc_pair_id_key"
)
logger.info("Migration 6 (schema cleanup) completed successfully")
logger.info("Legacy schema has been fully removed")
def downgrade() -> None:
"""Recreate dropped columns and tables (structure only, no data)."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.warning("Downgrading schema cleanup - recreating structure only, no data!")
# Recreate user_file columns
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
if "cc_pair_id" not in columns:
op.add_column(
"user_file", sa.Column("cc_pair_id", sa.Integer(), nullable=True)
)
if "folder_id" not in columns:
op.add_column(
"user_file", sa.Column("folder_id", sa.Integer(), nullable=True)
)
# Recreate chat_folder table
if "chat_folder" not in inspector.get_table_names():
op.create_table(
"chat_folder",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"], ["user.id"], name="chat_folder_user_fk"
),
)
# Recreate persona__user_folder table
if "persona__user_folder" not in inspector.get_table_names():
op.create_table(
"persona__user_folder",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("user_folder_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("persona_id", "user_folder_id"),
sa.ForeignKeyConstraint(["persona_id"], ["persona.id"]),
sa.ForeignKeyConstraint(["user_folder_id"], ["user_project.id"]),
)
# Add folder_id back to chat_session
if "chat_session" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("chat_session")]
if "folder_id" not in columns:
op.add_column(
"chat_session", sa.Column("folder_id", sa.Integer(), nullable=True)
)
# Add foreign key if chat_folder exists
if "chat_folder" in inspector.get_table_names():
op.create_foreign_key(
"chat_session_folder_fk",
"chat_session",
"chat_folder",
["folder_id"],
["id"],
)
logger.info("Downgrade completed - structure recreated but data is lost")

View File

@@ -1,298 +0,0 @@
"""Migration 5: User file legacy data cleanup
Revision ID: 3a78dba1080a
Revises: 7cc3fcc116c1
Create Date: 2025-09-22 10:04:27.986294
This migration removes legacy user-file documents and connector_credential_pairs.
It performs bulk deletions of obsolete data after the UUID migration.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
from sqlalchemy import text
import logging
from typing import List
import uuid
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "3a78dba1080a"
down_revision = "7cc3fcc116c1"
branch_labels = None
depends_on = None
def batch_delete(
bind: sa.engine.Connection,
table_name: str,
id_column: str,
ids: List[str | int | uuid.UUID],
batch_size: int = 1000,
id_type: str = "int",
) -> int:
"""Delete records in batches to avoid memory issues and timeouts."""
total_count = len(ids)
if total_count == 0:
return 0
logger.info(
f"Starting batch deletion of {total_count} records from {table_name}..."
)
# Determine appropriate ARRAY type
if id_type == "uuid":
array_type = psql.ARRAY(psql.UUID(as_uuid=True))
elif id_type == "int":
array_type = psql.ARRAY(sa.Integer())
else:
array_type = psql.ARRAY(sa.String())
total_deleted = 0
failed_batches = []
for i in range(0, total_count, batch_size):
batch_ids = ids[i : i + batch_size]
try:
stmt = text(
f"DELETE FROM {table_name} WHERE {id_column} = ANY(:ids)"
).bindparams(sa.bindparam("ids", value=batch_ids, type_=array_type))
result = bind.execute(stmt)
total_deleted += result.rowcount
# Log progress every 10 batches or at completion
batch_num = (i // batch_size) + 1
if batch_num % 10 == 0 or i + batch_size >= total_count:
logger.info(
f" Deleted {min(i + batch_size, total_count)}/{total_count} records "
f"({total_deleted} actual) from {table_name}"
)
except Exception as e:
logger.error(f"Failed to delete batch {(i // batch_size) + 1}: {e}")
failed_batches.append((i, min(i + batch_size, total_count)))
if failed_batches:
logger.warning(
f"Failed to delete {len(failed_batches)} batches from {table_name}. "
f"Total deleted: {total_deleted}/{total_count}"
)
# Fail the migration to avoid silently succeeding on partial cleanup
raise RuntimeError(
f"Batch deletion failed for {table_name}: "
f"{len(failed_batches)} failed batches out of "
f"{(total_count + batch_size - 1) // batch_size}."
)
return total_deleted
def upgrade() -> None:
"""Remove legacy user-file documents and connector_credential_pairs."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting legacy data cleanup...")
# === Step 1: Identify and delete user-file documents ===
logger.info("Identifying user-file documents to delete...")
# Get document IDs to delete
doc_rows = bind.execute(
text(
"""
SELECT DISTINCT dcc.id AS document_id
FROM document_by_connector_credential_pair dcc
JOIN connector_credential_pair u
ON u.connector_id = dcc.connector_id
AND u.credential_id = dcc.credential_id
WHERE u.is_user_file IS TRUE
"""
)
).fetchall()
doc_ids = [r[0] for r in doc_rows]
if doc_ids:
logger.info(f"Found {len(doc_ids)} user-file documents to delete")
# Delete dependent rows first
tables_to_clean = [
("document_retrieval_feedback", "document_id"),
("document__tag", "document_id"),
("chunk_stats", "document_id"),
]
for table_name, column_name in tables_to_clean:
if table_name in inspector.get_table_names():
# document_id is a string in these tables
deleted = batch_delete(
bind, table_name, column_name, doc_ids, id_type="str"
)
logger.info(f"Deleted {deleted} records from {table_name}")
# Delete document_by_connector_credential_pair entries
deleted = batch_delete(
bind, "document_by_connector_credential_pair", "id", doc_ids, id_type="str"
)
logger.info(f"Deleted {deleted} document_by_connector_credential_pair records")
# Delete documents themselves
deleted = batch_delete(bind, "document", "id", doc_ids, id_type="str")
logger.info(f"Deleted {deleted} document records")
else:
logger.info("No user-file documents found to delete")
# === Step 2: Clean up user-file connector_credential_pairs ===
logger.info("Cleaning up user-file connector_credential_pairs...")
# Get cc_pair IDs
cc_pair_rows = bind.execute(
text(
"""
SELECT id AS cc_pair_id
FROM connector_credential_pair
WHERE is_user_file IS TRUE
"""
)
).fetchall()
cc_pair_ids = [r[0] for r in cc_pair_rows]
if cc_pair_ids:
logger.info(
f"Found {len(cc_pair_ids)} user-file connector_credential_pairs to clean up"
)
# Delete related records
# Clean child tables first to satisfy foreign key constraints,
# then the parent tables
tables_to_clean = [
("index_attempt_errors", "connector_credential_pair_id"),
("index_attempt", "connector_credential_pair_id"),
("background_error", "cc_pair_id"),
("document_set__connector_credential_pair", "connector_credential_pair_id"),
("user_group__connector_credential_pair", "cc_pair_id"),
]
for table_name, column_name in tables_to_clean:
if table_name in inspector.get_table_names():
deleted = batch_delete(
bind, table_name, column_name, cc_pair_ids, id_type="int"
)
logger.info(f"Deleted {deleted} records from {table_name}")
# === Step 3: Identify connectors and credentials to delete ===
logger.info("Identifying orphaned connectors and credentials...")
# Get connectors used only by user-file cc_pairs
connector_rows = bind.execute(
text(
"""
SELECT DISTINCT ccp.connector_id
FROM connector_credential_pair ccp
WHERE ccp.is_user_file IS TRUE
AND ccp.connector_id != 0 -- Exclude system default
AND NOT EXISTS (
SELECT 1
FROM connector_credential_pair c2
WHERE c2.connector_id = ccp.connector_id
AND c2.is_user_file IS NOT TRUE
)
"""
)
).fetchall()
userfile_only_connector_ids = [r[0] for r in connector_rows]
# Get credentials used only by user-file cc_pairs
credential_rows = bind.execute(
text(
"""
SELECT DISTINCT ccp.credential_id
FROM connector_credential_pair ccp
WHERE ccp.is_user_file IS TRUE
AND ccp.credential_id != 0 -- Exclude public/default
AND NOT EXISTS (
SELECT 1
FROM connector_credential_pair c2
WHERE c2.credential_id = ccp.credential_id
AND c2.is_user_file IS NOT TRUE
)
"""
)
).fetchall()
userfile_only_credential_ids = [r[0] for r in credential_rows]
# === Step 4: Delete the cc_pairs themselves ===
if cc_pair_ids:
# Remove FK dependency from user_file first
bind.execute(
text(
"""
DO $$
DECLARE r RECORD;
BEGIN
FOR r IN (
SELECT conname
FROM pg_constraint c
JOIN pg_class t ON c.conrelid = t.oid
JOIN pg_class ft ON c.confrelid = ft.oid
WHERE c.contype = 'f'
AND t.relname = 'user_file'
AND ft.relname = 'connector_credential_pair'
) LOOP
EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
END LOOP;
END$$;
"""
)
)
# Delete cc_pairs
deleted = batch_delete(
bind, "connector_credential_pair", "id", cc_pair_ids, id_type="int"
)
logger.info(f"Deleted {deleted} connector_credential_pair records")
# === Step 5: Delete orphaned connectors ===
if userfile_only_connector_ids:
deleted = batch_delete(
bind, "connector", "id", userfile_only_connector_ids, id_type="int"
)
logger.info(f"Deleted {deleted} orphaned connector records")
# === Step 6: Delete orphaned credentials ===
if userfile_only_credential_ids:
# Clean up credential__user_group mappings first
deleted = batch_delete(
bind,
"credential__user_group",
"credential_id",
userfile_only_credential_ids,
id_type="int",
)
logger.info(f"Deleted {deleted} credential__user_group records")
# Delete credentials
deleted = batch_delete(
bind, "credential", "id", userfile_only_credential_ids, id_type="int"
)
logger.info(f"Deleted {deleted} orphaned credential records")
logger.info("Migration 5 (legacy data cleanup) completed successfully")
def downgrade() -> None:
"""Cannot restore deleted data - requires backup restoration."""
logger.error("CRITICAL: Downgrading data cleanup cannot restore deleted data!")
logger.error("Data restoration requires backup files or database backup.")
raise NotImplementedError(
"Downgrade of legacy data cleanup is not supported. "
"Deleted data must be restored from backups."
)
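batch_delete is the workhorse of this cleanup: it chunks the ID list, binds each chunk as a typed Postgres array, and raises if any chunk fails so a partial cleanup cannot pass silently. A hedged usage sketch inside an upgrade step, against a hypothetical table that is not part of this changeset:

# Sketch only: calling the batch_delete helper defined above from upgrade().
# "stale_event" is a hypothetical table used purely for illustration; op and
# logger come from the surrounding migration module.
from alembic import op
from sqlalchemy import text

bind = op.get_bind()
stale_ids = [
    row[0]
    for row in bind.execute(
        text("SELECT id FROM stale_event WHERE created_at < now() - interval '90 days'")
    ).fetchall()
]
deleted = batch_delete(bind, "stale_event", "id", stale_ids, batch_size=500, id_type="int")
logger.info(f"Deleted {deleted} stale_event rows")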

View File

@@ -1,37 +0,0 @@
"""Add image input support to model config
Revision ID: 64bd5677aeb6
Revises: b30353be4eec
Create Date: 2025-09-28 15:48:12.003612
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "64bd5677aeb6"
down_revision = "b30353be4eec"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"model_configuration",
sa.Column("supports_image_input", sa.Boolean(), nullable=True),
)
    # Seems to be left over from when model visibility was introduced as a nullable field.
# Set any null is_visible values to False
connection = op.get_bind()
connection.execute(
sa.text(
"UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
)
)
def downgrade() -> None:
op.drop_column("model_configuration", "supports_image_input")

View File

@@ -1,193 +0,0 @@
"""Migration 4: User file UUID primary key swap
Revision ID: 7cc3fcc116c1
Revises: 16c37a30adf2
Create Date: 2025-09-22 09:54:38.292952
This migration performs the critical UUID primary key swap on user_file table.
It updates all foreign key references to use UUIDs instead of integers.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "7cc3fcc116c1"
down_revision = "16c37a30adf2"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Swap user_file primary key from integer to UUID."""
bind = op.get_bind()
inspector = sa.inspect(bind)
# Verify we're in the expected state
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
if "new_id" not in user_file_columns:
logger.warning(
"user_file.new_id not found - migration may have already been applied"
)
return
logger.info("Starting UUID primary key swap...")
# === Step 1: Update persona__user_file foreign key to UUID ===
logger.info("Updating persona__user_file foreign key...")
# Drop existing foreign key constraints
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_uuid_fkey"
)
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
)
# Create new foreign key to user_file.new_id
op.create_foreign_key(
"persona__user_file_user_file_id_fkey",
"persona__user_file",
"user_file",
local_cols=["user_file_id_uuid"],
remote_cols=["new_id"],
)
# Drop the old integer column and rename UUID column
op.execute("ALTER TABLE persona__user_file DROP COLUMN IF EXISTS user_file_id")
op.alter_column(
"persona__user_file",
"user_file_id_uuid",
new_column_name="user_file_id",
existing_type=psql.UUID(as_uuid=True),
nullable=False,
)
# Recreate composite primary key
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_pkey"
)
op.execute(
"ALTER TABLE persona__user_file ADD PRIMARY KEY (persona_id, user_file_id)"
)
logger.info("Updated persona__user_file to use UUID foreign key")
# === Step 2: Perform the primary key swap on user_file ===
logger.info("Swapping user_file primary key to UUID...")
# Drop the primary key constraint
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_pkey")
# Drop the old id column and rename new_id to id
op.execute("ALTER TABLE user_file DROP COLUMN IF EXISTS id")
op.alter_column(
"user_file",
"new_id",
new_column_name="id",
existing_type=psql.UUID(as_uuid=True),
nullable=False,
)
# Set default for new inserts
op.alter_column(
"user_file",
"id",
existing_type=psql.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
)
# Create new primary key
op.execute("ALTER TABLE user_file ADD PRIMARY KEY (id)")
logger.info("Swapped user_file primary key to UUID")
# === Step 3: Update foreign key constraints ===
logger.info("Updating foreign key constraints...")
# Recreate persona__user_file foreign key to point to user_file.id
# Drop existing FK first to break dependency on the unique constraint
op.execute(
"ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey"
)
# Drop the unique constraint on (formerly) new_id BEFORE recreating the FK,
# so the FK will bind to the primary key instead of the unique index.
op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS uq_user_file_new_id")
# Now recreate FK to the primary key column
op.create_foreign_key(
"persona__user_file_user_file_id_fkey",
"persona__user_file",
"user_file",
local_cols=["user_file_id"],
remote_cols=["id"],
)
# Add foreign keys for project__user_file
existing_fks = inspector.get_foreign_keys("project__user_file")
has_user_file_fk = any(
fk.get("referred_table") == "user_file"
and fk.get("constrained_columns") == ["user_file_id"]
for fk in existing_fks
)
if not has_user_file_fk:
op.create_foreign_key(
"fk_project__user_file_user_file_id",
"project__user_file",
"user_file",
["user_file_id"],
["id"],
)
logger.info("Added project__user_file -> user_file foreign key")
has_project_fk = any(
fk.get("referred_table") == "user_project"
and fk.get("constrained_columns") == ["project_id"]
for fk in existing_fks
)
if not has_project_fk:
op.create_foreign_key(
"fk_project__user_file_project_id",
"project__user_file",
"user_project",
["project_id"],
["id"],
)
logger.info("Added project__user_file -> user_project foreign key")
# === Step 4: Mark files for document_id migration ===
logger.info("Marking files for background document_id migration...")
logger.info("Migration 4 (UUID primary key swap) completed successfully")
logger.info(
"NOTE: Background task will update document IDs in Vespa and search_doc"
)
def downgrade() -> None:
"""Revert UUID primary key back to integer (data destructive!)."""
logger.error("CRITICAL: Downgrading UUID primary key swap is data destructive!")
logger.error(
"This will break all UUID-based references created after the migration."
)
logger.error("Only proceed if absolutely necessary and have backups.")
# The downgrade would need to:
# 1. Add back integer columns
# 2. Generate new sequential IDs
# 3. Update all foreign key references
# 4. Swap primary keys back
# This is complex and risky, so we raise an error instead
raise NotImplementedError(
"Downgrade of UUID primary key swap is not supported due to data loss risk. "
"Manual intervention with data backup/restore is required."
)
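Condensed, the swap follows a generic recipe: the transitional UUID column added earlier becomes the primary key, gains a server-side default for new rows, and foreign keys are re-pointed at it. A rough sketch of that recipe on a hypothetical table (names are illustrative, not from this migration):

# Sketch only: promoting a backfilled UUID column ("new_id") to the primary key
# of a hypothetical "item" table inside an Alembic upgrade().
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql

op.execute("ALTER TABLE item DROP CONSTRAINT IF EXISTS item_pkey")
op.execute("ALTER TABLE item DROP COLUMN IF EXISTS id")
op.alter_column(
    "item",
    "new_id",
    new_column_name="id",
    existing_type=psql.UUID(as_uuid=True),
    nullable=False,
)
op.alter_column(
    "item",
    "id",
    existing_type=psql.UUID(as_uuid=True),
    server_default=sa.text("gen_random_uuid()"),
)
op.execute("ALTER TABLE item ADD PRIMARY KEY (id)")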

View File

@@ -1,257 +0,0 @@
"""Migration 1: User file schema additions
Revision ID: 9b66d3156fc6
Revises: b4ef3ae0bf6e
Create Date: 2025-09-22 09:42:06.086732
This migration adds new columns and tables without modifying existing data.
It is safe to run and can be easily rolled back.
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
import logging
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "9b66d3156fc6"
down_revision = "b4ef3ae0bf6e"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""Add new columns and tables without modifying existing data."""
# Enable pgcrypto for UUID generation
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")
bind = op.get_bind()
inspector = sa.inspect(bind)
# === USER_FILE: Add new columns ===
logger.info("Adding new columns to user_file table...")
user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]
# Check if ID is already UUID (in case of re-run after partial migration)
id_is_uuid = any(
col["name"] == "id" and "uuid" in str(col["type"]).lower()
for col in inspector.get_columns("user_file")
)
# Add transitional UUID column only if ID is not already UUID
if "new_id" not in user_file_columns and not id_is_uuid:
op.add_column(
"user_file",
sa.Column(
"new_id",
psql.UUID(as_uuid=True),
nullable=True,
server_default=sa.text("gen_random_uuid()"),
),
)
op.create_unique_constraint("uq_user_file_new_id", "user_file", ["new_id"])
logger.info("Added new_id column to user_file")
# Add status column
if "status" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"status",
sa.Enum(
"PROCESSING",
"COMPLETED",
"FAILED",
"CANCELED",
name="userfilestatus",
native_enum=False,
),
nullable=False,
server_default="PROCESSING",
),
)
logger.info("Added status column to user_file")
# Add other tracking columns
if "chunk_count" not in user_file_columns:
op.add_column(
"user_file", sa.Column("chunk_count", sa.Integer(), nullable=True)
)
logger.info("Added chunk_count column to user_file")
if "last_accessed_at" not in user_file_columns:
op.add_column(
"user_file",
sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=True),
)
logger.info("Added last_accessed_at column to user_file")
if "needs_project_sync" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"needs_project_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
logger.info("Added needs_project_sync column to user_file")
if "last_project_sync_at" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"last_project_sync_at", sa.DateTime(timezone=True), nullable=True
),
)
logger.info("Added last_project_sync_at column to user_file")
if "document_id_migrated" not in user_file_columns:
op.add_column(
"user_file",
sa.Column(
"document_id_migrated",
sa.Boolean(),
nullable=False,
server_default=sa.text("true"),
),
)
logger.info("Added document_id_migrated column to user_file")
# === USER_FOLDER -> USER_PROJECT rename ===
table_names = set(inspector.get_table_names())
if "user_folder" in table_names:
logger.info("Updating user_folder table...")
# Make description nullable first
op.alter_column("user_folder", "description", nullable=True)
# Rename table if user_project doesn't exist
if "user_project" not in table_names:
op.execute("ALTER TABLE user_folder RENAME TO user_project")
logger.info("Renamed user_folder to user_project")
elif "user_project" in table_names:
# If already renamed, ensure column nullability
project_cols = [col["name"] for col in inspector.get_columns("user_project")]
if "description" in project_cols:
op.alter_column("user_project", "description", nullable=True)
# Add instructions column to user_project
inspector = sa.inspect(bind) # Refresh after rename
if "user_project" in inspector.get_table_names():
project_columns = [col["name"] for col in inspector.get_columns("user_project")]
if "instructions" not in project_columns:
op.add_column(
"user_project",
sa.Column("instructions", sa.String(), nullable=True),
)
logger.info("Added instructions column to user_project")
# === CHAT_SESSION: Add project_id ===
chat_session_columns = [
col["name"] for col in inspector.get_columns("chat_session")
]
if "project_id" not in chat_session_columns:
op.add_column(
"chat_session",
sa.Column("project_id", sa.Integer(), nullable=True),
)
logger.info("Added project_id column to chat_session")
# === PERSONA__USER_FILE: Add UUID column ===
persona_user_file_columns = [
col["name"] for col in inspector.get_columns("persona__user_file")
]
if "user_file_id_uuid" not in persona_user_file_columns:
op.add_column(
"persona__user_file",
sa.Column("user_file_id_uuid", psql.UUID(as_uuid=True), nullable=True),
)
logger.info("Added user_file_id_uuid column to persona__user_file")
# === PROJECT__USER_FILE: Create new table ===
if "project__user_file" not in inspector.get_table_names():
op.create_table(
"project__user_file",
sa.Column("project_id", sa.Integer(), nullable=False),
sa.Column("user_file_id", psql.UUID(as_uuid=True), nullable=False),
sa.PrimaryKeyConstraint("project_id", "user_file_id"),
)
op.create_index(
"idx_project__user_file_user_file_id",
"project__user_file",
["user_file_id"],
)
logger.info("Created project__user_file table")
logger.info("Migration 1 (schema additions) completed successfully")
def downgrade() -> None:
"""Remove added columns and tables."""
bind = op.get_bind()
inspector = sa.inspect(bind)
logger.info("Starting downgrade of schema additions...")
# Drop project__user_file table
if "project__user_file" in inspector.get_table_names():
op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
op.drop_table("project__user_file")
logger.info("Dropped project__user_file table")
# Remove columns from persona__user_file
if "persona__user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
if "user_file_id_uuid" in columns:
op.drop_column("persona__user_file", "user_file_id_uuid")
logger.info("Dropped user_file_id_uuid from persona__user_file")
# Remove columns from chat_session
if "chat_session" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("chat_session")]
if "project_id" in columns:
op.drop_column("chat_session", "project_id")
logger.info("Dropped project_id from chat_session")
# Rename user_project back to user_folder and remove instructions
if "user_project" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_project")]
if "instructions" in columns:
op.drop_column("user_project", "instructions")
op.execute("ALTER TABLE user_project RENAME TO user_folder")
op.alter_column("user_folder", "description", nullable=False)
logger.info("Renamed user_project back to user_folder")
# Remove columns from user_file
if "user_file" in inspector.get_table_names():
columns = [col["name"] for col in inspector.get_columns("user_file")]
columns_to_drop = [
"document_id_migrated",
"last_project_sync_at",
"needs_project_sync",
"last_accessed_at",
"chunk_count",
"status",
]
for col in columns_to_drop:
if col in columns:
op.drop_column("user_file", col)
logger.info(f"Dropped {col} from user_file")
if "new_id" in columns:
op.drop_constraint("uq_user_file_new_id", "user_file", type_="unique")
op.drop_column("user_file", "new_id")
logger.info("Dropped new_id from user_file")
# Drop enum type if no columns use it
bind.execute(sa.text("DROP TYPE IF EXISTS userfilestatus"))
logger.info("Downgrade completed successfully")

View File

@@ -1,123 +0,0 @@
"""add_mcp_auth_performer
Revision ID: b30353be4eec
Revises: 2b75d0a8ffcb
Create Date: 2025-09-13 14:58:08.413534
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import MCPAuthenticationPerformer, MCPTransport
# revision identifiers, used by Alembic.
revision = "b30353be4eec"
down_revision = "2b75d0a8ffcb"
branch_labels = None
depends_on = None
def upgrade() -> None:
"""moving to a better way of handling auth performer and transport"""
# Add nullable column first for backward compatibility
op.add_column(
"mcp_server",
sa.Column(
"auth_performer",
sa.Enum(MCPAuthenticationPerformer, native_enum=False),
nullable=True,
),
)
op.add_column(
"mcp_server",
sa.Column(
"transport",
sa.Enum(MCPTransport, native_enum=False),
nullable=True,
),
)
    # Backfill values using existing data and inference rules
bind = op.get_bind()
# 1) OAUTH servers are always PER_USER
bind.execute(
sa.text(
"""
UPDATE mcp_server
SET auth_performer = 'PER_USER'
WHERE auth_type = 'OAUTH'
"""
)
)
    # 2) If there is no admin connection config and auth_performer is not already set, mark as ADMIN
bind.execute(
sa.text(
"""
UPDATE mcp_server
SET auth_performer = 'ADMIN'
WHERE admin_connection_config_id IS NULL
AND auth_performer IS NULL
"""
)
)
# 3) If there exists any user-specific connection config (user_email != ''), mark as PER_USER
bind.execute(
sa.text(
"""
UPDATE mcp_server AS ms
SET auth_performer = 'PER_USER'
FROM mcp_connection_config AS mcc
WHERE mcc.mcp_server_id = ms.id
AND COALESCE(mcc.user_email, '') <> ''
AND ms.auth_performer IS NULL
"""
)
)
# 4) Default any remaining nulls to ADMIN (covers API_TOKEN admin-managed and NONE)
bind.execute(
sa.text(
"""
UPDATE mcp_server
SET auth_performer = 'ADMIN'
WHERE auth_performer IS NULL
"""
)
)
# Finally, make the column non-nullable
op.alter_column(
"mcp_server",
"auth_performer",
existing_type=sa.Enum(MCPAuthenticationPerformer, native_enum=False),
nullable=False,
)
# Backfill transport for existing rows to STREAMABLE_HTTP, then make non-nullable
bind.execute(
sa.text(
"""
UPDATE mcp_server
SET transport = 'STREAMABLE_HTTP'
WHERE transport IS NULL
"""
)
)
op.alter_column(
"mcp_server",
"transport",
existing_type=sa.Enum(MCPTransport, native_enum=False),
nullable=False,
)
def downgrade() -> None:
"""remove cols"""
op.drop_column("mcp_server", "transport")
op.drop_column("mcp_server", "auth_performer")

View File

@@ -124,9 +124,9 @@ def get_space_permission(
and not space_permissions.external_user_group_ids
):
logger.warning(
f"No permissions found for space '{space_key}'. This is very unlikely "
"to be correct and is more likely caused by an access token with "
"insufficient permissions. Make sure that the access token has Admin "
f"No permissions found for space '{space_key}'. This is very unlikely"
"to be correct and is more likely caused by an access token with"
"insufficient permissions. Make sure that the access token has Admin"
f"permissions for space '{space_key}'"
)

View File

@@ -26,7 +26,7 @@ def _get_slim_doc_generator(
else 0.0
)
return gmail_connector.retrieve_all_slim_docs_perm_sync(
return gmail_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -34,7 +34,7 @@ def _get_slim_doc_generator(
else 0.0
)
return google_drive_connector.retrieve_all_slim_docs_perm_sync(
return google_drive_connector.retrieve_all_slim_documents(
start=start_time,
end=current_time.timestamp(),
callback=callback,

View File

@@ -59,7 +59,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
for raw_perm in permissions:
if not hasattr(raw_perm, "raw"):
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
logger.warn(f"Expected a 'raw' field, but none was found: {raw_perm=}")
continue
permission = Permission(**raw_perm.raw)
@@ -71,14 +71,14 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
# In order to associate this permission to some Atlassian entity, we need the "Holder".
# If this doesn't exist, then we cannot associate this permission to anyone; just skip.
if not permission.holder:
logger.warning(
logger.warn(
f"Expected to find a permission holder, but none was found: {permission=}"
)
continue
type = permission.holder.get("type")
if not type:
logger.warning(
logger.warn(
f"Expected to find the type of permission holder, but none was found: {permission=}"
)
continue

View File

@@ -105,9 +105,7 @@ def _get_slack_document_access(
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync(
callback=callback
)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:

View File

@@ -4,7 +4,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ def generic_doc_sync(
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnectorWithPermSync,
slim_connector: SlimConnector,
label: str,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -40,7 +40,7 @@ def generic_doc_sync(
newly_fetched_doc_ids: set[str] = set()
logger.info(f"Fetching all slim documents from {doc_source}")
for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync(callback=callback):
for doc_batch in slim_connector.retrieve_all_slim_documents(callback=callback):
logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}")
if callback:

View File

@@ -16,7 +16,6 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# saml
("/auth/saml/authorize", {"GET"}),
("/auth/saml/callback", {"POST"}),
("/auth/saml/callback", {"GET"}),
("/auth/saml/logout", {"POST"}),
]

View File

@@ -182,6 +182,7 @@ def admin_get_chat_sessions(
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
)
for chat in chat_sessions

View File

@@ -110,6 +110,7 @@ async def upsert_saml_user(email: str) -> User:
async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
form_data = await request.form()
if request.client is None:
raise ValueError("Invalid request for SAML")
@@ -124,27 +125,14 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
"post_data": {},
"get_data": {},
}
# Handle query parameters (for GET requests)
if request.query_params:
rv["get_data"] = dict(request.query_params)
# Handle form data (for POST requests)
if request.method == "POST":
form_data = await request.form()
if "SAMLResponse" in form_data:
SAMLResponse = form_data["SAMLResponse"]
rv["post_data"]["SAMLResponse"] = SAMLResponse
if "RelayState" in form_data:
RelayState = form_data["RelayState"]
rv["post_data"]["RelayState"] = RelayState
else:
# For GET requests, check if SAMLResponse is in query params
if "SAMLResponse" in request.query_params:
rv["get_data"]["SAMLResponse"] = request.query_params["SAMLResponse"]
if "RelayState" in request.query_params:
rv["get_data"]["RelayState"] = request.query_params["RelayState"]
rv["get_data"] = (request.query_params,)
if "SAMLResponse" in form_data:
SAMLResponse = form_data["SAMLResponse"]
rv["post_data"]["SAMLResponse"] = SAMLResponse
if "RelayState" in form_data:
RelayState = form_data["RelayState"]
rv["post_data"]["RelayState"] = RelayState
return rv
@@ -160,27 +148,10 @@ async def saml_login(request: Request) -> SAMLAuthorizeResponse:
return SAMLAuthorizeResponse(authorization_url=callback_url)
@router.get("/callback")
async def saml_login_callback_get(
request: Request,
db_session: Session = Depends(get_session),
) -> Response:
"""Handle SAML callback via HTTP-Redirect binding (GET request)"""
return await _process_saml_callback(request, db_session)
@router.post("/callback")
async def saml_login_callback(
request: Request,
db_session: Session = Depends(get_session),
) -> Response:
"""Handle SAML callback via HTTP-POST binding (POST request)"""
return await _process_saml_callback(request, db_session)
async def _process_saml_callback(
request: Request,
db_session: Session,
) -> Response:
req = await prepare_from_fastapi_request(request)
auth = OneLogin_Saml2_Auth(req, custom_base_path=SAML_CONF_DIR)

View File

@@ -6,6 +6,7 @@ from typing import Optional
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from litellm.exceptions import RateLimitError
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
@@ -206,8 +207,6 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
from litellm.exceptions import RateLimitError
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.access.models import DocumentAccess
@@ -11,7 +10,6 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -126,25 +124,3 @@ def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> b
),
)
return _source_should_fetch_permissions_during_indexing_func(source)
def get_access_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, DocumentAccess]:
user_files = (
db_session.query(UserFile)
.options(joinedload(UserFile.user)) # Eager load the user relationship
.filter(UserFile.id.in_(user_file_ids))
.all()
)
return {
str(user_file.id): DocumentAccess.build(
user_emails=[user_file.user.email] if user_file.user else [],
user_groups=[],
is_public=True if user_file.user is None else False,
external_user_emails=[],
external_user_group_ids=[],
)
for user_file in user_files
}

View File

@@ -24,8 +24,6 @@ def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
return END
elif next_tool_name == DRPath.LOGGER.value:
return DRPath.LOGGER
elif next_tool_name == DRPath.CLOSER.value:
return DRPath.CLOSER
else:
return DRPath.ORCHESTRATOR

View File

@@ -35,24 +35,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.agents.agent_search.utils import create_question_prompt
from onyx.chat.chat_utils import build_citation_map_from_numbers
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import build_citations_system_message
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
from onyx.chat.stream_processing.citation_processing import (
normalize_square_bracket_citations_to_double_with_links,
)
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceDescription
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.db.chat import create_search_doc_from_saved_search_doc
from onyx.db.chat import update_db_session_with_messages
from onyx.db.connector import fetch_unique_document_sources
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.models import SearchDoc
from onyx.db.models import Tool
from onyx.db.tools import get_tools
from onyx.file_store.models import ChatFileType
@@ -62,7 +52,6 @@ from onyx.kg.utils.extraction_utils import get_relationship_types_str
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
@@ -321,52 +310,6 @@ def _get_existing_clarification_request(
return clarification, original_question, chat_history_string
def _persist_final_docs_and_citations(
db_session: Session,
context_llm_docs: list[Any] | None,
full_answer: str | None,
) -> tuple[list[SearchDoc], dict[int, int] | None]:
"""Persist final documents from in-context docs and derive citation mapping.
Returns the list of persisted `SearchDoc` records and an optional
citation map translating inline [[n]] references to DB doc indices.
"""
final_documents_db: list[SearchDoc] = []
citations_map: dict[int, int] | None = None
if not context_llm_docs:
return final_documents_db, citations_map
saved_search_docs = saved_search_docs_from_llm_docs(context_llm_docs)
for saved_doc in saved_search_docs:
db_doc = create_search_doc_from_saved_search_doc(saved_doc)
db_session.add(db_doc)
final_documents_db.append(db_doc)
db_session.flush()
cited_numbers: set[int] = set()
try:
# Match [[1]] or [[1, 2]] optionally followed by a link like ([[1]](http...))
matches = re.findall(
r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", full_answer or ""
)
for match in matches:
for num_str in match.split(","):
num = int(num_str.strip())
cited_numbers.add(num)
except Exception:
cited_numbers = set()
if cited_numbers and final_documents_db:
translations = build_citation_map_from_numbers(
cited_numbers=cited_numbers,
db_docs=final_documents_db,
)
citations_map = translations or None
return final_documents_db, citations_map
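    # Illustration (not part of the original file): what the citation regex above matches.
    # The sample string and values below are invented for demonstration.
    # >>> import re
    # >>> sample = "See [[1]] for details and ([[2, 3]](https://example.com)) too."
    # >>> re.findall(r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", sample)
    # ['1', '2, 3']
    # Splitting each group on "," and casting to int yields the cited set {1, 2, 3}.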
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
"type": "function",
"function": {
@@ -478,13 +421,6 @@ def clarifier(
assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
assistant_task_prompt = ""
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ graph_config.inputs.project_instructions
)
chat_history_string = (
get_chat_history_string(
graph_config.inputs.prompt_builder.message_history,
@@ -513,11 +449,6 @@ def clarifier(
graph_config.inputs.files
)
# Use project/search context docs if available to enable citation mapping
context_llm_docs = getattr(
graph_config.inputs.prompt_builder, "context_llm_docs", None
)
if not (force_use_tool and force_use_tool.force_use):
if not use_tool_calling_llm or len(available_tools) == 1:
@@ -632,44 +563,10 @@ def clarifier(
active_source_type_descriptions_str=active_source_type_descriptions_str,
)
if context_llm_docs:
persona = graph_config.inputs.persona
if persona is not None:
prompt_config = PromptConfig.from_model(persona)
else:
prompt_config = PromptConfig(
system_prompt=assistant_system_prompt,
task_prompt="",
datetime_aware=True,
)
system_prompt_to_use_content = build_citations_system_message(
prompt_config
).content
system_prompt_to_use: str = cast(str, system_prompt_to_use_content)
if graph_config.inputs.project_instructions:
system_prompt_to_use = (
system_prompt_to_use
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ graph_config.inputs.project_instructions
)
user_prompt_to_use = build_citations_user_message(
user_query=original_question,
files=[],
prompt_config=prompt_config,
context_docs=context_llm_docs,
all_doc_useful=False,
history_message=chat_history_string,
context_type="user files",
).content
else:
system_prompt_to_use = assistant_system_prompt
user_prompt_to_use = decision_prompt + assistant_task_prompt
stream = graph_config.tooling.primary_llm.stream(
prompt=create_question_prompt(
cast(str, system_prompt_to_use),
cast(str, user_prompt_to_use),
assistant_system_prompt,
decision_prompt + assistant_task_prompt,
uploaded_image_context=uploaded_image_context,
),
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
@@ -682,8 +579,6 @@ def clarifier(
should_stream_answer=True,
writer=writer,
ind=0,
final_search_results=context_llm_docs,
displayed_search_results=context_llm_docs,
generate_final_answer=True,
chat_message_id=str(graph_config.persistence.chat_session_id),
)
@@ -691,32 +586,19 @@ def clarifier(
if len(full_response.ai_message_chunk.tool_calls) == 0:
if isinstance(full_response.full_answer, str):
full_answer = (
normalize_square_bracket_citations_to_double_with_links(
full_response.full_answer
)
)
full_answer = full_response.full_answer
else:
full_answer = None
# Persist final documents and derive citations when using in-context docs
final_documents_db, citations_map = _persist_final_docs_and_citations(
db_session=db_session,
context_llm_docs=context_llm_docs,
full_answer=full_answer,
)
update_db_session_with_messages(
db_session=db_session,
chat_message_id=message_id,
chat_session_id=graph_config.persistence.chat_session_id,
is_agentic=graph_config.behavior.use_agentic_search,
message=full_answer,
token_count=len(llm_tokenizer.encode(full_answer or "")),
citations=citations_map,
final_documents=final_documents_db or None,
update_parent_message=True,
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
token_count=len(llm_tokenizer.encode(full_answer or "")),
)
db_session.commit()

View File

@@ -181,15 +181,6 @@ def orchestrator(
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[research_type]
elif remaining_time_budget <= 0:
write_custom_event(
current_step_nr,
SectionEnd(),
writer,
)
current_step_nr += 1
return OrchestrationUpdate(
tools_used=[DRPath.CLOSER.value],
current_step_nr=current_step_nr,

View File

@@ -42,7 +42,6 @@ from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.llm.utils import check_number_of_tokens
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
@@ -226,7 +225,7 @@ def closer(
research_type = graph_config.behavior.research_type
assistant_system_prompt: str = state.assistant_system_prompt or ""
assistant_system_prompt = state.assistant_system_prompt
assistant_task_prompt = state.assistant_task_prompt
uploaded_context = state.uploaded_test_context or ""
@@ -350,13 +349,6 @@ def closer(
uploaded_context=uploaded_context,
)
if graph_config.inputs.project_instructions:
assistant_system_prompt = (
assistant_system_prompt
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ (graph_config.inputs.project_instructions or "")
)
all_context_llmdocs = [
llm_doc_from_inference_section(inference_section)
for inference_section in all_cited_documents

View File

@@ -9,7 +9,6 @@ from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
@@ -19,8 +18,6 @@ from onyx.chat.stream_processing.answer_response_handler import (
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -59,9 +56,6 @@ def process_llm_stream(
full_answer = ""
start_final_answer_streaming_set = False
# Accumulate citation infos if handler emits them
collected_citation_infos: list[CitationInfo] = []
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
@@ -108,9 +102,6 @@ def process_llm_stream(
MessageDelta(content=response_part.answer_piece),
writer,
)
# collect citation info objects
elif isinstance(response_part, CitationInfo):
collected_citation_infos.append(response_part)
if generate_final_answer and start_final_answer_streaming_set:
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
@@ -120,14 +111,6 @@ def process_llm_stream(
writer,
)
# Emit citations section if any were collected
if collected_citation_infos:
write_custom_event(ind, CitationStart(), writer)
write_custom_event(
ind, CitationDelta(citations=collected_citation_infos), writer
)
write_custom_event(ind, SectionEnd(), writer)
logger.debug(f"Full answer: {full_answer}")
return BasicSearchProcessedStreamResults(
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer

View File

@@ -1,7 +1,6 @@
import re
from datetime import datetime
from typing import cast
from uuid import UUID
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
@@ -74,7 +73,6 @@ def basic_search(
search_tool_info = state.available_tools[state.tools_used[-1]]
search_tool = cast(SearchTool, search_tool_info.tool_object)
force_use_tool = graph_config.tooling.force_use_tool
# sanity check
if search_tool != graph_config.tooling.search_tool:
@@ -143,15 +141,6 @@ def basic_search(
retrieved_docs: list[InferenceSection] = []
callback_container: list[list[InferenceSection]] = []
user_file_ids: list[UUID] | None = None
project_id: int | None = None
if force_use_tool.override_kwargs and isinstance(
force_use_tool.override_kwargs, SearchToolOverrideKwargs
):
override_kwargs = force_use_tool.override_kwargs
user_file_ids = override_kwargs.user_file_ids
project_id = override_kwargs.project_id
# new db session to avoid concurrency issues
with get_session_with_current_tenant() as search_db_session:
for tool_response in search_tool.run(
@@ -164,8 +153,6 @@ def basic_search(
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
original_query=rewritten_query,
user_file_ids=user_file_ids,
project_id=project_id,
),
):
# get retrieved docs to send to the rest of the graph

View File

@@ -5,12 +5,12 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
from onyx.agents.agent_search.dr.utils import chunks_or_sections_to_search_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.utils.logger import setup_logger
@@ -47,7 +47,7 @@ def is_reducer(
doc_list.append(x)
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
search_docs = chunks_or_sections_to_search_docs(doc_list)
retrieved_saved_search_docs = [
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
for search_doc in search_docs

View File

@@ -1,147 +0,0 @@
import json
from concurrent.futures import ThreadPoolExecutor
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(InternetSearchProvider):
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
self.headers = {
"X-API-KEY": api_key,
"Content-Type": "application/json",
}
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[InternetSearchResult]:
payload = {
"q": query,
}
response = requests.post(
SERPER_SEARCH_URL,
headers=self.headers,
data=json.dumps(payload),
)
response.raise_for_status()
results = response.json()
organic_results = results["organic"]
return [
InternetSearchResult(
title=result["title"],
link=result["link"],
snippet=result["snippet"],
author=None,
published_date=None,
)
for result in organic_results
]
def contents(self, urls: list[str]) -> list[InternetContent]:
if not urls:
return []
        # Serper can respond with 500s regularly. We want to retry,
        # but in the event of failure, return an unsuccessful scrape.
def safe_get_webpage_content(url: str) -> InternetContent:
try:
return self._get_webpage_content(url)
except Exception:
return InternetContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
with ThreadPoolExecutor(max_workers=min(8, len(urls))) as e:
return list(e.map(safe_get_webpage_content, urls))
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> InternetContent:
payload = {
"url": url,
}
response = requests.post(
SERPER_CONTENTS_URL,
headers=self.headers,
data=json.dumps(payload),
)
# 400 returned when serper cannot scrape
if response.status_code == 400:
return InternetContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
response.raise_for_status()
response_json = response.json()
# Response only guarantees text
text = response_json["text"]
        # metadata & jsonld are not guaranteed to be present
metadata = response_json.get("metadata", {})
jsonld = response_json.get("jsonld", {})
title = extract_title_from_metadata(metadata)
# Serper does not provide a reliable mechanism to extract the url
response_url = url
published_date_str = extract_published_date_from_jsonld(jsonld)
published_date = None
if published_date_str:
try:
published_date = time_str_to_utc(published_date_str)
except Exception:
published_date = None
return InternetContent(
title=title or "",
link=response_url,
full_content=text or "",
published_date=published_date,
)
def extract_title_from_metadata(metadata: dict[str, str]) -> str | None:
keys = ["title", "og:title"]
return extract_value_from_dict(metadata, keys)
def extract_published_date_from_jsonld(jsonld: dict[str, str]) -> str | None:
keys = ["dateModified"]
return extract_value_from_dict(jsonld, keys)
def extract_value_from_dict(data: dict[str, str], keys: list[str]) -> str | None:
for key in keys:
if key in data:
return data[key]
return None
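As shown above, SerperClient exposes two calls: search() returns ranked organic results and contents() scrapes each URL in parallel, flagging failures instead of raising. A hedged usage sketch (the query and URL selection are invented; a valid SERPER_API_KEY is assumed to be configured):

# Illustrative usage only; assumes SERPER_API_KEY is set.
client = SerperClient()
results = client.search("onyx enterprise search")
pages = client.contents([r.link for r in results[:3]])
for page in pages:
    if page.scrape_successful:
        print(page.title, len(page.full_content))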

View File

@@ -26,7 +26,6 @@ class InternetContent(BaseModel):
link: str
full_content: str
published_date: datetime | None = None
scrape_successful: bool = True
class InternetSearchProvider(ABC):

View File

@@ -1,19 +1,13 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.exa_client import (
ExaClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client import (
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> InternetSearchProvider | None:
if EXA_API_KEY:
return ExaClient()
if SERPER_API_KEY:
return SerperClient()
return None

View File

@@ -34,7 +34,7 @@ def dummy_inference_section_from_internet_content(
boost=1,
recency_bias=1.0,
score=1.0,
hidden=(not result.scrape_successful),
hidden=False,
metadata={},
match_highlights=[],
doc_summary=truncated_content,

View File

@@ -13,7 +13,7 @@ from onyx.agents.agent_search.shared_graph_utils.operators import (
)
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
@@ -266,7 +266,7 @@ def convert_inference_sections_to_search_docs(
is_internet: bool = False,
) -> list[SavedSearchDoc]:
# Convert InferenceSections to SavedSearchDocs
search_docs = SearchDoc.from_chunks_or_sections(inference_sections)
search_docs = chunks_or_sections_to_search_docs(inference_sections)
for search_doc in search_docs:
search_doc.is_internet = is_internet

View File

@@ -24,7 +24,6 @@ class GraphInputs(BaseModel):
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
project_instructions: str | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel
from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult

View File

@@ -8,6 +8,8 @@ from typing import TypeVar
from langchain.schema.language_model import LanguageModelInput
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from litellm import get_supported_openai_params
from litellm import supports_response_schema
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
@@ -145,7 +147,6 @@ def invoke_llm_json(
Invoke an LLM, forcing it to respond in a specified JSON format if possible,
and return an object of that schema.
"""
from litellm.utils import get_supported_openai_params, supports_response_schema
# check if the model supports response_format: json_schema
supports_json = "response_format" in (

View File

@@ -115,6 +115,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -323,6 +323,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.kg_processing",
]
)

View File

@@ -1,113 +0,0 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.user_file_processing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Fewer startup checks in the multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -19,9 +19,7 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -32,7 +30,7 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
@@ -43,24 +41,20 @@ def extract_ids_from_runnable_connector(
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
If the SlimConnector hasn't been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
doc_batch_id_generator = None
if isinstance(runnable_connector, SlimConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs()
)
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs_perm_sync()
)
# If the connector isn't slim, fall back to running it normally to get ids
elif isinstance(runnable_connector, LoadConnector):
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
@@ -84,14 +78,13 @@ def extract_ids_from_runnable_connector(
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
# this function is called per batch for rate limiting
doc_batch_processing_func = (
rate_limit_builder(
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(lambda x: x)
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
else lambda x: x
)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
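For context on the rate_limit_builder(max_calls=..., period=60)(...) idiom above: a minimal illustrative sketch of such a decorator factory (an assumption about typical behavior, not Onyx's actual implementation):
import time
from collections import deque
from functools import wraps
def rate_limit_builder(*, max_calls: int, period: float):
    def decorator(fn):
        calls: deque[float] = deque()
        @wraps(fn)
        def wrapper(*args, **kwargs):
            now = time.monotonic()
            # drop call timestamps that have left the window
            while calls and now - calls[0] > period:
                calls.popleft()
            if len(calls) >= max_calls:
                # wait until the oldest call leaves the window
                time.sleep(max(0.0, period - (now - calls[0])))
                calls.popleft()
            calls.append(time.monotonic())
            return fn(*args, **kwargs)
        return wrapper
    return decorator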

View File

@@ -1,22 +0,0 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# User file processing worker configuration
worker_concurrency = CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -26,26 +26,6 @@ CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = [
{
"name": "check-for-user-file-processing",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
},
{
"name": "user-file-docid-migration",
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
"schedule": timedelta(minutes=1),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.USER_FILE_PROCESSING,
},
},
{
"name": "check-for-kg-processing",
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
@@ -85,9 +65,9 @@ beat_task_templates: list[dict] = [
{
"name": "check-for-index-attempt-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP,
"schedule": timedelta(minutes=30),
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -109,6 +89,17 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-user-file-folder-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
"schedule": timedelta(
days=1
), # This should essentially always be triggered manually for user folder updates.
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,

View File

@@ -28,6 +28,9 @@ from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import (
delete_userfiles_for_cc_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
@@ -481,6 +484,12 @@ def monitor_connector_deletion_taskset(
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# delete all userfiles for the cc_pair
delete_userfiles_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,

View File

@@ -85,9 +85,6 @@ from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -1372,14 +1369,6 @@ def _docprocessing_task(
f"Processing {len(documents)} documents through indexing pipeline"
)
adapter = DocumentIndexingBatchAdapter(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
tenant_id=tenant_id,
index_attempt_metadata=index_attempt_metadata,
)
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
@@ -1389,8 +1378,7 @@ def _docprocessing_task(
db_session=db_session,
tenant_id=tenant_id,
document_batch=documents,
request_id=index_attempt_metadata.request_id,
adapter=adapter,
index_attempt_metadata=index_attempt_metadata,
)
# Update batch completion and document counts atomically using database coordination

View File

@@ -889,12 +889,6 @@ def monitor_celery_queues_helper(
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_user_file_processing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
n_user_file_project_sync = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -922,8 +916,6 @@ def monitor_celery_queues_helper(
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"user_file_processing={n_user_file_processing} "
f"user_file_project_sync={n_user_file_project_sync} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -0,0 +1,266 @@
import time
from typing import List
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_with_user_files,
)
from onyx.db.document import get_document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.search_settings import get_active_search_settings
from onyx.db.user_documents import fetch_user_files_for_documents
from onyx.db.user_documents import fetch_user_folders_for_documents
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_user_file_folder_sync(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check for documents that need user file folder metadata updates.
This task fetches all connector credential pairs with user files, gets the documents
associated with them, and updates the user file and folder metadata in Vespa.
"""
time_start = time.monotonic()
r = get_redis_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK,
timeout=CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
with get_session_with_current_tenant() as db_session:
# Get all connector credential pairs that have user files
cc_pairs = get_connector_credential_pairs_with_user_files(db_session)
if not cc_pairs:
task_logger.info("No connector credential pairs with user files found")
return True
# Get all documents associated with these cc_pairs
document_ids = get_documents_for_cc_pairs(cc_pairs, db_session)
if not document_ids:
task_logger.info(
"No documents found for connector credential pairs with user files"
)
return True
# Fetch current user file and folder IDs for these documents
doc_id_to_user_file_id = fetch_user_files_for_documents(
document_ids=document_ids, db_session=db_session
)
doc_id_to_user_folder_id = fetch_user_folders_for_documents(
document_ids=document_ids, db_session=db_session
)
# Update Vespa metadata for each document
for doc_id in document_ids:
user_file_id = doc_id_to_user_file_id.get(doc_id)
user_folder_id = doc_id_to_user_folder_id.get(doc_id)
if user_file_id is not None or user_folder_id is not None:
# Schedule a task to update the document metadata
update_user_file_folder_metadata.apply_async(
args=(doc_id,), # Use tuple instead of list for args
kwargs={
"tenant_id": tenant_id,
"user_file_id": user_file_id,
"user_folder_id": user_folder_id,
},
queue="vespa_metadata_sync",
)
task_logger.info(
f"Scheduled metadata updates for {len(document_ids)} documents. "
f"Elapsed time: {time.monotonic() - time_start:.2f}s"
)
return True
except Exception as e:
task_logger.exception(f"Error in check_for_user_file_folder_sync: {e}")
return False
finally:
lock_beat.release()
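The non-overlap guard used above (and in the other beat tasks in this diff) is a recurring idiom: acquire a Redis lock without blocking, bail out if another run already holds it, and release in a finally block. A minimal standalone sketch, assuming a configured redis-py client:
from collections.abc import Callable
from redis import Redis
from redis.lock import Lock as RedisLock
def run_exclusive(r: Redis, lock_name: str, timeout: int, work: Callable[[], bool]) -> bool | None:
    lock: RedisLock = r.lock(lock_name, timeout=timeout)
    # non-blocking acquire: overlapping beat firings exit immediately
    if not lock.acquire(blocking=False):
        return None
    try:
        return work()
    finally:
        lock.release()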
def get_documents_for_cc_pairs(
cc_pairs: List[ConnectorCredentialPair], db_session: Session
) -> List[str]:
"""Get all document IDs associated with the given connector credential pairs."""
if not cc_pairs:
return []
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs]
# Query to get document IDs from DocumentByConnectorCredentialPair
# Note: DocumentByConnectorCredentialPair uses connector_id and credential_id, not cc_pair_id
doc_cc_pairs = (
db_session.query(Document.id)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.filter(
db_session.query(ConnectorCredentialPair)
.filter(
ConnectorCredentialPair.id.in_(cc_pair_ids),
ConnectorCredentialPair.connector_id
== DocumentByConnectorCredentialPair.connector_id,
ConnectorCredentialPair.credential_id
== DocumentByConnectorCredentialPair.credential_id,
)
.exists()
)
.all()
)
return [doc_id for (doc_id,) in doc_cc_pairs]
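As the note above says, DocumentByConnectorCredentialPair is keyed by (connector_id, credential_id) rather than cc_pair_id; the same lookup can be sketched with an explicit tuple filter instead of the correlated EXISTS (illustrative only, reusing the models imported above):
from sqlalchemy import tuple_
def get_documents_for_cc_pairs_via_tuples(
    cc_pairs: List[ConnectorCredentialPair], db_session: Session
) -> List[str]:
    if not cc_pairs:
        return []
    pairs = [(cp.connector_id, cp.credential_id) for cp in cc_pairs]
    rows = (
        db_session.query(Document.id)
        .join(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
        )
        .filter(
            tuple_(
                DocumentByConnectorCredentialPair.connector_id,
                DocumentByConnectorCredentialPair.credential_id,
            ).in_(pairs)
        )
        .all()
    )
    return [doc_id for (doc_id,) in rows]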
@shared_task(
name=OnyxCeleryTask.UPDATE_USER_FILE_FOLDER_METADATA,
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def update_user_file_folder_metadata(
self: Task,
document_id: str,
*,
tenant_id: str,
user_file_id: int | None,
user_folder_id: int | None,
) -> bool:
"""Updates the user file and folder metadata for a document in Vespa."""
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
doc = get_document(document_id, db_session)
if not doc:
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
return False
# Create user fields object with file and folder IDs
user_fields = VespaDocumentUserFields(
user_file_id=str(user_file_id) if user_file_id is not None else None,
user_folder_id=(
str(user_folder_id) if user_folder_id is not None else None
),
)
# Update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=None, # We're only updating user fields
user_fields=user_fields,
)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=user_file_folder_sync "
f"user_file_id={user_file_id} "
f"user_folder_id={user_folder_id} "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
return True
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
task_logger.exception(
f"update_user_file_folder_metadata exceptioned: doc={document_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"update_user_file_folder_metadata completed: status={completion_status.value} doc={document_id}"
)
return False

View File

@@ -1,699 +0,0 @@
import datetime
import time
from collections.abc import Sequence
from typing import Any
from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import UserFileStatus
from onyx.db.models import FileRecord
from onyx.db.models import SearchDoc
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import S3BackedFileStore
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_pool import get_redis_client
def _as_uuid(value: str | UUID) -> UUID:
"""Return a UUID, accepting either a UUID or a string-like value."""
return value if isinstance(value, UUID) else UUID(str(value))
def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING,
soft_time_limit=300,
bind=True,
ignore_result=True,
)
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_PROCESSING_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# Do not overlap generator runs
if not lock.acquire(blocking=False):
return None
enqueued = 0
try:
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.status == UserFileStatus.PROCESSING
)
)
.scalars()
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -> None:
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
)
return None
documents: list[Document] = []
try:
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if not uf:
task_logger.warning(
f"process_single_user_file - UserFile not found id={user_file_id}"
)
return None
if uf.status != UserFileStatus.PROCESSING:
task_logger.info(
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
)
return None
connector = LocalFileConnector(
file_locations=[uf.file_id],
file_names=[uf.name] if uf.name else None,
zip_metadata={},
)
connector.load_credentials({})
# 20 is the documented default for httpx max_keepalive_connections
if MANAGED_VESPA:
httpx_init_vespa_pool(
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
)
else:
httpx_init_vespa_pool(20)
search_settings_list = get_active_search_settings_list(db_session)
current_search_settings = next(
(
search_settings_instance
for search_settings_instance in search_settings_list
if search_settings_instance.status.is_current()
),
None,
)
if current_search_settings is None:
raise RuntimeError(
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
)
try:
for batch in connector.load_from_state():
documents.extend(batch)
adapter = UserFileIndexingAdapter(
tenant_id=tenant_id,
db_session=db_session,
)
# Set up indexing pipeline components
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=current_search_settings,
)
information_content_classification_model = (
InformationContentClassificationModel()
)
document_index = get_default_document_index(
current_search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
)
# update the document id to the user file id in the documents
for document in documents:
document.id = str(user_file_id)
document.source = DocumentSource.USER_FILE
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
document_batch=documents,
request_id=None,
adapter=adapter,
)
task_logger.info(
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
)
if (
index_pipeline_result.failures
or index_pipeline_result.total_docs != len(documents)
or index_pipeline_result.total_chunks == 0
):
task_logger.error(
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
)
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
except Exception as e:
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
elapsed = time.monotonic() - start
task_logger.info(
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
)
return None
except Exception as e:
# Attempt to mark the file as failed
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if uf:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC,
soft_time_limit=300,
bind=True,
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.USER_FILE_PROJECT_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not lock.acquire(blocking=False):
return None
enqueued = 0
try:
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
UserFile.needs_project_sync.is_(True)
and UserFile.status == UserFileStatus.COMPLETED
)
)
.scalars()
.all()
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=OnyxCeleryPriority.HIGH,
)
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str
) -> None:
"""Process a single user file project sync."""
task_logger.info(
f"process_single_user_file_project_sync - Starting id={user_file_id}"
)
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
)
return None
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
)
return None
project_ids = [project.id for project in user_file.projects]
chunks_affected = retry_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
)
task_logger.info(
f"process_single_user_file_project_sync - Chunks affected id={user_file_id} chunks={chunks_affected}"
)
user_file.needs_project_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)
db_session.add(user_file)
db_session.commit()
except Exception as e:
task_logger.exception(
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
file_lock.release()
return None
def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
# Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
user_prefix = "USER_FILE_CONNECTOR__"
file_prefix = "FILE_CONNECTOR__"
if old_id.startswith(user_prefix):
remainder = old_id[len(user_prefix) :]
return file_prefix + remainder
return old_id
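For example, given the prefixes above:
_normalize_legacy_user_file_doc_id("USER_FILE_CONNECTOR__1234-abcd")  # -> "FILE_CONNECTOR__1234-abcd"
_normalize_legacy_user_file_doc_id("FILE_CONNECTOR__1234-abcd")       # unchanged
_normalize_legacy_user_file_doc_id("some_other_doc_id")               # unchanged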
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict[str, Any]], str | None]:
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
}
if continuation:
params["continuation"] = continuation
resp = http_client.get(base_url, params=params, timeout=None)
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def _update_document_id_in_vespa(
*,
index_name: str,
old_doc_id: str,
new_doc_id: str,
user_project_ids: list[int] | None = None,
) -> None:
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
normalized_old = _normalize_legacy_user_file_doc_id(old_doc_id)
clean_old_doc_id = replace_invalid_doc_id_characters(normalized_old)
selection = f"{index_name}.document_id=='{clean_old_doc_id}'"
task_logger.debug(f"Vespa selection: {selection}")
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request: dict[str, Any] = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
if user_project_ids is not None:
update_request["fields"][USER_PROJECT] = {
"assign": user_project_ids
}
r = http_client.put(vespa_url, json=update_request)
r.raise_for_status()
if not continuation:
break
@shared_task(
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
ignore_result=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
bind=True,
)
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
"""Per-tenant job to update Vespa and search_doc document_id values for user files.
- For each user_file with a legacy document_id, set Vespa `document_id` to the UUID `user_file.id`.
- Update `search_doc.document_id` to the same UUID string.
"""
try:
with get_session_with_current_tenant() as db_session:
active_settings = get_active_search_settings(db_session)
document_index = get_default_document_index(
active_settings.primary,
active_settings.secondary,
)
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
index_name = "danswer_index"
# Fetch mappings of legacy -> new ids
rows = db_session.execute(
sa.select(
UserFile.document_id.label("document_id"),
UserFile.id.label("id"),
).where(
UserFile.document_id.is_not(None),
UserFile.document_id_migrated.is_(False),
)
).all()
# dedupe by old document_id
seen: set[str] = set()
for row in rows:
old_doc_id = str(row.document_id)
new_uuid = str(row.id)
if not old_doc_id or not new_uuid or old_doc_id in seen:
continue
seen.add(old_doc_id)
# collect user project ids for a combined Vespa update
user_project_ids: list[int] | None = None
try:
uf = db_session.get(UserFile, UUID(new_uuid))
if uf is not None:
user_project_ids = [project.id for project in uf.projects]
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed fetching projects for doc_id={new_uuid} - {e.__class__.__name__}"
)
try:
_update_document_id_in_vespa(
index_name=index_name,
old_doc_id=old_doc_id,
new_doc_id=new_uuid,
user_project_ids=user_project_ids,
)
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed Vespa update for doc_id={new_uuid} - {e.__class__.__name__}"
)
# Update search_doc records to refer to the UUID string
# we do not filter on document_id_migrated == False here: if the migration already completed,
# it would not run again and the search_doc records affected by the issue fixed here would never be updated
user_files = (
db_session.execute(
sa.select(UserFile).where(UserFile.document_id.is_not(None))
)
.scalars()
.all()
)
# Query all SearchDocs that need updating
search_docs = (
db_session.execute(
sa.select(SearchDoc).where(
SearchDoc.document_id.like("%FILE_CONNECTOR__%")
)
)
.scalars()
.all()
)
task_logger.info(f"Found {len(user_files)} user files to update")
task_logger.info(f"Found {len(search_docs)} search docs to update")
# Build a map of normalized doc IDs to SearchDocs
search_doc_map: dict[str, list[SearchDoc]] = {}
for sd in search_docs:
doc_id = sd.document_id
if search_doc_map.get(doc_id) is None:
search_doc_map[doc_id] = []
search_doc_map[doc_id].append(sd)
task_logger.debug(
f"Built search doc map with {len(search_doc_map)} entries"
)
ids_preview = list(search_doc_map.keys())[:5]
task_logger.debug(
f"First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
)
task_logger.debug(
f"search_doc_map total items: {sum(len(docs) for docs in search_doc_map.values())}"
)
# Process each UserFile and update matching SearchDocs
updated_count = 0
for uf in user_files:
doc_id = uf.document_id
if doc_id.startswith("USER_FILE_CONNECTOR__"):
doc_id = "FILE_CONNECTOR__" + doc_id[len("USER_FILE_CONNECTOR__") :]
task_logger.debug(f"Processing user file {uf.id} with doc_id {doc_id}")
task_logger.debug(
f"doc_id in search_doc_map: {doc_id in search_doc_map}"
)
if doc_id in search_doc_map:
search_docs = search_doc_map[doc_id]
task_logger.debug(
f"Found {len(search_docs)} search docs to update for user file {uf.id}"
)
# Update the SearchDoc to use the UserFile's UUID
for search_doc in search_docs:
search_doc.document_id = str(uf.id)
db_session.add(search_doc)
# Mark UserFile as migrated
uf.document_id_migrated = True
db_session.add(uf)
updated_count += 1
task_logger.info(
f"Updated {updated_count} SearchDoc records with new UUIDs"
)
db_session.commit()
# Normalize plaintext FileRecord blobs: ensure S3 object key aligns with current file_id
try:
store = get_default_file_store()
# Only supported for S3-backed stores where we can manipulate object keys
if isinstance(store, S3BackedFileStore):
s3_client = store._get_s3_client()
bucket_name = store._get_bucket_name()
plaintext_records: Sequence[FileRecord] = (
db_session.execute(
sa.select(FileRecord).where(
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
FileRecord.file_id.like("plaintext_%"),
)
)
.scalars()
.all()
)
normalized = 0
for fr in plaintext_records:
try:
expected_key = store._get_s3_key(fr.file_id)
if fr.object_key == expected_key:
continue
# Copy old object to new key
copy_source = f"{fr.bucket_name}/{fr.object_key}"
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=expected_key,
MetadataDirective="COPY",
)
# Delete old object (best-effort)
try:
s3_client.delete_object(
Bucket=fr.bucket_name, Key=fr.object_key
)
except Exception:
pass
# Update DB record with new key
fr.object_key = expected_key
db_session.add(fr)
normalized += 1
except Exception as e:
task_logger.warning(
f"Tenant={tenant_id} failed plaintext object normalize for "
f"id={fr.file_id} - {e.__class__.__name__}"
)
if normalized:
db_session.commit()
task_logger.info(
f"user_file_docid_migration_task normalized {normalized} plaintext objects for tenant={tenant_id}"
)
else:
task_logger.info(
"user_file_docid_migration_task skipping plaintext object normalization (non-S3 store)"
)
except Exception:
task_logger.exception(
f"user_file_docid_migration_task - Error during plaintext normalization for tenant={tenant_id}"
)
task_logger.info(
f"user_file_docid_migration_task completed for tenant={tenant_id} (rows={len(rows)})"
)
return True
except Exception:
task_logger.exception(
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id}"
)
return False

View File

@@ -1,16 +0,0 @@
"""Factory stub for running the user file processing Celery worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.user_file_processing import celery_app
return celery_app
app = get_app()

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
def get_old_index_attempts(
@@ -22,10 +21,6 @@ def get_old_index_attempts(
def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None:
"""Clean up multiple index attempts"""
db_session.query(IndexAttemptError).filter(
IndexAttemptError.index_attempt_id.in_(index_attempt_ids)
).delete(synchronize_session=False)
db_session.query(IndexAttempt).filter(
IndexAttempt.id.in_(index_attempt_ids)
).delete(synchronize_session=False)

View File

@@ -64,11 +64,9 @@ from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@@ -102,7 +100,6 @@ def _get_connector_runner(
are the complete list of existing documents of the connector. If the task
is of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
task = attempt.connector_credential_pair.connector.input_type
try:
@@ -286,8 +283,6 @@ def _run_indexing(
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic()  # just used for logging
with get_session_with_current_tenant() as db_session_temp:
@@ -572,13 +567,6 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
adapter = DocumentIndexingBatchAdapter(
db_session=db_session,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
tenant_id=tenant_id,
index_attempt_metadata=index_attempt_md,
)
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
@@ -590,8 +578,7 @@ def _run_indexing(
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
request_id=index_attempt_md.request_id,
adapter=adapter,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1

View File

@@ -62,7 +62,6 @@ class Answer:
use_agentic_search: bool = False,
research_type: ResearchType | None = None,
research_plan: dict[str, Any] | None = None,
project_instructions: str | None = None,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
self._processed_stream: list[AnswerStreamPart] | None = None
@@ -98,7 +97,6 @@ class Answer:
prompt_builder=prompt_builder,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
project_instructions=project_instructions,
)
self.graph_tooling = GraphTooling(
primary_llm=llm,

View File

@@ -32,7 +32,6 @@ from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -344,45 +343,6 @@ def reorganize_citations(
return new_answer, list(new_citation_info.values())
def build_citation_map_from_infos(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
"""Translate a list of streaming CitationInfo objects into a mapping of
citation number -> saved search doc DB id.
Always cites the first instance of a document_id and assumes db_docs are
ordered as shown to the user (display order).
"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
for db_doc in db_docs:
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
citation_to_saved_doc_id_map: dict[int, int] = {}
for citation in citations_list:
if citation.citation_num not in citation_to_saved_doc_id_map:
saved_id = doc_id_to_saved_doc_id_map.get(citation.document_id)
if saved_id is not None:
citation_to_saved_doc_id_map[citation.citation_num] = saved_id
return citation_to_saved_doc_id_map
def build_citation_map_from_numbers(
cited_numbers: list[int] | set[int], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
"""Translate parsed citation numbers (e.g., from [[n]]) into a mapping of
citation number -> saved search doc DB id by positional index.
"""
citation_to_saved_doc_id_map: dict[int, int] = {}
for num in sorted(set(cited_numbers)):
idx = num - 1
if 0 <= idx < len(db_docs):
citation_to_saved_doc_id_map[num] = db_docs[idx].id
return citation_to_saved_doc_id_map
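For example, a positional mapping under the helper above (illustrative stand-in objects; only .id is read):
from types import SimpleNamespace
docs = [SimpleNamespace(id=11), SimpleNamespace(id=22), SimpleNamespace(id=33)]
build_citation_map_from_numbers({2, 5}, docs)  # -> {2: 22}; 5 is out of range and dropped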
def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:

View File

@@ -5,7 +5,6 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from typing import Protocol
from uuid import UUID
from sqlalchemy.orm import Session
@@ -19,7 +18,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ChatBasicResponse
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import PromptConfig
@@ -37,7 +35,6 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
@@ -66,13 +63,9 @@ from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.projects import get_project_instructions
from onyx.db.projects import get_user_files_from_project
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import load_all_chat_files
from onyx.kg.models import KGException
from onyx.llm.exceptions import GenAIDisabledException
@@ -108,7 +101,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
@@ -127,66 +119,6 @@ class PartialResponse(Protocol):
) -> ChatMessage: ...
def _build_project_llm_docs(
project_file_ids: list[str] | None,
in_memory_user_files: list[InMemoryChatFile] | None,
) -> list[LlmDoc]:
"""Construct `LlmDoc` objects for project-scoped user files for citation flow."""
project_llm_docs: list[LlmDoc] = []
if not project_file_ids or not in_memory_user_files:
return project_llm_docs
project_file_id_set = set(project_file_ids)
for f in in_memory_user_files:
if project_file_id_set and (f.file_id in project_file_id_set):
def _strip_nuls(s: str) -> str:
return s.replace("\x00", "") if s else s
cleaned_filename = _strip_nuls(f.filename or str(f.file_id))
if f.file_type.is_text_file():
try:
text_content = f.content.decode("utf-8", errors="ignore")
text_content = _strip_nuls(text_content)
except Exception:
text_content = ""
# Build a short blurb from the file content for better UI display
blurb = (
(text_content[:200] + "...")
if len(text_content) > 200
else text_content
)
else:
# Non-text (e.g., images): do not decode bytes; keep empty content but allow citation
text_content = ""
blurb = f"[{f.file_type.value}] {cleaned_filename}"
# Provide basic metadata to improve SavedSearchDoc display
file_metadata: dict[str, str | list[str]] = {
"filename": cleaned_filename,
"file_type": f.file_type.value,
}
project_llm_docs.append(
LlmDoc(
document_id=str(f.file_id),
content=text_content,
blurb=blurb,
semantic_identifier=cleaned_filename,
source_type=DocumentSource.USER_FILE,
metadata=file_metadata,
updated_at=None,
link=build_frontend_file_url(str(f.file_id)),
source_links=None,
match_highlights=None,
)
)
return project_llm_docs
def _translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> MessageSpecificCitations:
@@ -504,29 +436,26 @@ def stream_chat_message_objects(
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids: list[UUID] = []
user_file_ids = new_msg_req.user_file_ids or []
user_folder_ids = new_msg_req.user_folder_ids or []
if persona.user_files:
for uf in persona.user_files:
user_file_ids.append(uf.id)
if new_msg_req.current_message_files:
for fd in new_msg_req.current_message_files:
uid = fd.get("user_file_id")
if uid is not None:
user_file_id = UUID(uid)
user_file_ids.append(user_file_id)
for file in persona.user_files:
user_file_ids.append(file.id)
if persona.user_folders:
for folder in persona.user_folders:
user_folder_ids.append(folder.id)
# Load in user files into memory and create search tool override kwargs if needed
# if we have enough tokens, we don't need to use search
# if we have enough tokens and no folders, we don't need to use search
# we can just pass them into the prompt directly
(
in_memory_user_files,
user_file_models,
search_tool_override_kwargs_for_user_files,
) = parse_user_files(
user_file_ids=user_file_ids or [],
project_id=chat_session.project_id,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
db_session=db_session,
persona=persona,
actual_user_input=message_text,
@@ -535,37 +464,16 @@ def stream_chat_message_objects(
if not search_tool_override_kwargs_for_user_files:
latest_query_files.extend(in_memory_user_files)
project_file_ids = []
if chat_session.project_id:
project_file_ids.extend(
[
file.file_id
for file in get_user_files_from_project(
chat_session.project_id, user_id, db_session
)
]
)
# we don't want to attach project files to the user message
if user_message:
attach_files_to_chat_message(
chat_message=user_message,
files=[
new_file.to_file_descriptor()
for new_file in latest_query_files
if project_file_ids is not None
and (new_file.file_id not in project_file_ids)
new_file.to_file_descriptor() for new_file in latest_query_files
],
db_session=db_session,
commit=False,
)
# Build project context docs for citation flow if project files are present
project_llm_docs: list[LlmDoc] = _build_project_llm_docs(
project_file_ids=project_file_ids,
in_memory_user_files=in_memory_user_files,
)
selected_db_search_docs = None
selected_sections: list[InferenceSection] | None = None
if reference_doc_ids:
@@ -651,22 +559,12 @@ def stream_chat_message_objects(
else:
prompt_config = PromptConfig.from_model(persona)
# Retrieve project-specific instructions if this chat session is associated with a project.
project_instructions: str | None = (
get_project_instructions(
db_session=db_session, project_id=chat_session.project_id
)
if persona.is_default_persona
else None
) # if the persona is not default, we don't want to use the project instructions
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
structured_response_format=new_msg_req.structured_response_format,
)
has_project_files = project_file_ids is not None and len(project_file_ids) > 0
tool_dict = construct_tools(
persona=persona,
@@ -676,17 +574,9 @@ def stream_chat_message_objects(
llm=llm,
fast_llm=fast_llm,
run_search_setting=(
OptionalSearchSetting.NEVER
if (
chat_session.project_id
and not has_project_files
and persona.is_default_persona
)
else (
retrieval_options.run_search
if retrieval_options
else OptionalSearchSetting.AUTO
)
retrieval_options.run_search
if retrieval_options
else OptionalSearchSetting.AUTO
),
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
@@ -728,7 +618,6 @@ def stream_chat_message_objects(
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
yield UserKnowledgeFilePacket(
user_files=[
@@ -736,8 +625,6 @@ def stream_chat_message_objects(
id=str(file.file_id), type=file.file_type, name=file.filename
)
for file in in_memory_user_files
if project_file_ids is not None
and (file.file_id not in project_file_ids)
]
)
@@ -756,10 +643,6 @@ def stream_chat_message_objects(
single_message_history=single_message_history,
)
if project_llm_docs and not search_tool_override_kwargs_for_user_files:
# Store for downstream streaming to wire citations and final_documents
prompt_builder.context_llm_docs = project_llm_docs
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
@@ -788,7 +671,6 @@ def stream_chat_message_objects(
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
project_instructions=project_instructions,
)
# Process streamed packets using the new packet processing module

View File

@@ -4,9 +4,9 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
@@ -76,7 +76,6 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
@@ -133,10 +132,6 @@ class AnswerPromptBuilder:
self.raw_user_uploaded_files = raw_user_uploaded_files
self.single_message_history = single_message_history
# Optional: if the prompt includes explicit context documents (e.g., project files),
# store them here so downstream streaming can reference them for citation mapping.
self.context_llm_docs: list[LlmDoc] | None = None
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
@@ -201,6 +196,10 @@ class AnswerPromptBuilder:
# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]
# TODO: rename this? AnswerConfig maybe?

View File

@@ -1,10 +0,0 @@
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from onyx.llm.models import PreviousMessage
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]

View File

@@ -12,35 +12,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def normalize_square_bracket_citations_to_double_with_links(text: str) -> str:
"""
Normalize citation markers in the text:
- Convert bare double-bracket citations without links `[[n]]` to `[[n]]()`
- Convert single-bracket citations `[n]` to `[[n]]()`
Leaves existing linked citations like `[[n]](http...)` unchanged.
"""
if not text:
return ""
# Add empty parens to bare double-bracket citations without a link: [[n]] -> [[n]]()
pattern_double_no_link = re.compile(r"\[\[(\d+)\]\](?!\()")
def _repl_double(match: re.Match[str]) -> str:
num = match.group(1)
return f"[[{num}]]()"
text = pattern_double_no_link.sub(_repl_double, text)
# Convert single [n] not already [[n]] to [[n]]()
pattern_single = re.compile(r"(?<!\[)\[(\d+)\](?!\])")
def _repl_single(match: re.Match[str]) -> str:
num = match.group(1)
return f"[[{num}]]()"
return pattern_single.sub(_repl_single, text)
def in_code_block(llm_text: str) -> bool:
count = llm_text.count(TRIPLE_BACKTICK)
return count % 2 != 0
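
As a quick illustration of the citation normalization removed above, here is a self-contained sketch using the same two regexes, with a sample transformation (helper name is ours, not from the codebase):

import re

def normalize_citations(text: str) -> str:
    # Add empty parens to bare double-bracket citations: [[n]] -> [[n]]()
    text = re.sub(r"\[\[(\d+)\]\](?!\()", r"[[\1]]()", text)
    # Convert single [n] (not already double-bracketed) to [[n]]()
    return re.sub(r"(?<!\[)\[(\d+)\](?!\])", r"[[\1]]()", text)

assert normalize_citations("See [1], [[2]] and [[3]](http://x)") == "See [[1]](), [[2]]() and [[3]](http://x)"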

View File

@@ -7,7 +7,7 @@ from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.prompt_builder.schemas import PromptSnapshot
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message

View File

@@ -4,8 +4,6 @@ from sqlalchemy.orm import Session
from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.user_file import update_last_accessed_at_for_user_files
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import get_user_files_as_user
from onyx.file_store.utils import load_in_memory_chat_files
@@ -17,24 +15,24 @@ logger = setup_logger()
def parse_user_files(
user_file_ids: list[UUID],
user_file_ids: list[int],
user_folder_ids: list[int],
db_session: Session,
persona: Persona,
actual_user_input: str,
project_id: int | None,
# should only be None if auth is disabled
user_id: UUID | None,
) -> tuple[list[InMemoryChatFile], list[UserFile], SearchToolOverrideKwargs | None]:
"""
Parse user files and project into in-memory chat files and create search tool override kwargs.
Only creates SearchToolOverrideKwargs if token overflow occurs.
Parse user files and folders into in-memory chat files and create search tool override kwargs.
Only creates SearchToolOverrideKwargs if token overflow occurs or folders are present.
Args:
user_file_ids: List of user file IDs to load
user_folder_ids: List of user folder IDs to load
db_session: Database session
persona: Persona to calculate available tokens
actual_user_input: User's input message for token calculation
project_id: Project ID to validate file ownership
user_id: User ID to validate file ownership
Returns:
@@ -42,56 +40,37 @@ def parse_user_files(
loaded user files,
user file models,
search tool override kwargs if token
overflow
overflow or folders present
)
"""
# Return empty results if no files or project specified
if not user_file_ids and not project_id:
# Return empty results if no files or folders specified
if not user_file_ids and not user_folder_ids:
return [], [], None
project_user_file_ids = []
if project_id:
project_user_file_ids.extend(
[
file.id
for file in get_user_files_from_project(project_id, user_id, db_session)
]
)
# Combine user-provided and project-derived user file IDs
combined_user_file_ids = user_file_ids + project_user_file_ids or []
# Load user files from the database into memory
user_files = load_in_memory_chat_files(
combined_user_file_ids,
user_file_ids or [],
user_folder_ids or [],
db_session,
)
user_file_models = get_user_files_as_user(
combined_user_file_ids,
user_file_ids or [],
user_folder_ids or [],
user_id,
db_session,
)
# Update last accessed at for the user files which are used in the chat
if user_file_ids or project_user_file_ids:
# update_last_accessed_at_for_user_files expects list[UUID]
update_last_accessed_at_for_user_files(
combined_user_file_ids,
db_session,
)
# Calculate token count for the files, need to import here to avoid circular import
# TODO: fix this
from onyx.db.user_file import calculate_user_files_token_count
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
# calculate_user_files_token_count now expects list[UUID]
total_tokens = calculate_user_files_token_count(
combined_user_file_ids,
user_file_ids or [],
user_folder_ids or [],
db_session,
)
@@ -100,31 +79,27 @@ def parse_user_files(
persona=persona,
actual_user_input=actual_user_input,
)
uploaded_context_cap = int(available_tokens * 0.5)
logger.debug(
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens},"
f"Allowed uploaded context tokens: {uploaded_context_cap}"
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
)
have_enough_tokens = total_tokens <= uploaded_context_cap
have_enough_tokens = total_tokens <= available_tokens
# If we have enough tokens, we don't need search
# If we have enough tokens and no folders, we don't need search
# we can just pass them into the prompt directly
if have_enough_tokens:
if have_enough_tokens and not user_folder_ids:
# No search tool override needed - files can be passed directly
return user_files, user_file_models, None
# Token overflow - need to use search tool
# Token overflow or folders present - need to use search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=have_enough_tokens,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=have_enough_tokens,
user_file_ids=user_file_ids or [],
project_id=(
project_id if persona.is_default_persona else None
), # if the persona is not default, we don't want to use the project files
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
)
return user_files, user_file_models, override_kwargs
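
For orientation, a minimal sketch of the decision this function makes, assuming a precomputed token total and a plain dict in place of SearchToolOverrideKwargs; the real code also handles reranking and query-analysis flags:

def decide_user_file_handling(
    total_file_tokens: int,
    available_tokens: int,
    user_file_ids: list[int],
) -> dict | None:
    """If the files fit in the prompt budget, pass them directly; otherwise
    fall back to retrieval over the uploaded files."""
    if total_file_tokens <= available_tokens:
        return None  # no override needed, files go straight into the prompt
    return {
        "force_no_rerank": False,
        "skip_query_analysis": False,
        "user_file_ids": user_file_ids,
    }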

View File

@@ -362,18 +362,6 @@ CELERY_WORKER_PRIMARY_CONCURRENCY = int(
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
try:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
os.environ.get(
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
)
)
except ValueError:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192
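
The removed concurrency setting above follows the usual guard for integer environment variables; as a generic sketch (the helper name is ours, not from the codebase):

import os

def env_int(name: str, default: int) -> int:
    """Read an integer env var, falling back on missing or malformed values."""
    try:
        return int(os.environ.get(name, default))
    except ValueError:
        return default

USER_FILE_PROCESSING_CONCURRENCY = env_int("CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY", 4)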

View File

@@ -3,6 +3,7 @@ import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
@@ -90,7 +91,6 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
SERPER_API_KEY = os.environ.get("SERPER_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)

View File

@@ -78,9 +78,6 @@ POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
"celery_worker_user_file_processing"
)
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@@ -117,6 +114,7 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -207,8 +205,6 @@ class DocumentSource(str, Enum):
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
# Special case for user files
USER_FILE = "user_file"
class FederatedConnectorSource(str, Enum):
@@ -304,7 +300,6 @@ class FileOrigin(str, Enum):
PLAINTEXT_CACHE = "plaintext_cache"
OTHER = "other"
QUERY_HISTORY_CSV = "query_history_csv"
USER_FILE = "user_file"
class FileType(str, Enum):
@@ -350,9 +345,6 @@ class OnyxCeleryQueues:
# Indexing queue
USER_FILES_INDEXING = "user_files_indexing"
# User file processing queue
USER_FILE_PROCESSING = "user_file_processing"
USER_FILE_PROJECT_SYNC = "user_file_project_sync"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
@@ -378,7 +370,7 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
CLOUD_PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
@@ -400,12 +392,6 @@ class OnyxRedisLocks:
# KG processing
KG_PROCESSING_LOCK = "da_lock:kg_processing"
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
class OnyxRedisSignals:
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
@@ -464,6 +450,8 @@ class OnyxCeleryTask:
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_pidbox"
)
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -471,12 +459,7 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
# User file processing
CHECK_FOR_USER_FILE_PROCESSING = "check_for_user_file_processing"
PROCESS_SINGLE_USER_FILE = "process_single_user_file"
CHECK_FOR_USER_FILE_PROJECT_SYNC = "check_for_user_file_project_sync"
PROCESS_SINGLE_USER_FILE_PROJECT_SYNC = "process_single_user_file_project_sync"
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
@@ -509,7 +492,6 @@ class OnyxCeleryTask:
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
USER_FILE_DOCID_MIGRATION = "user_file_docid_migration"
# chat retention
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"

View File

@@ -41,7 +41,7 @@ All new connectors should have tests added to the `backend/tests/daily/connector
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, CheckpointedConnector, or CheckpointedConnectorWithPermSync
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The `__init__` should take arguments for configuring what documents the connector will index and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -25,7 +25,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -56,7 +56,7 @@ class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
class BitbucketConnector(
CheckpointedConnector[BitbucketConnectorCheckpoint],
SlimConnectorWithPermSync,
SlimConnector,
):
"""Connector for indexing Bitbucket Cloud pull requests.
@@ -266,7 +266,7 @@ class BitbucketConnector(
"""Validate and deserialize a checkpoint instance from JSON."""
return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -5,7 +5,6 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from atlassian.errors import ApiError # type: ignore
from requests.exceptions import HTTPError
from typing_extensions import override
@@ -42,7 +41,6 @@ from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -93,7 +91,6 @@ class ConfluenceCheckpoint(ConnectorCheckpoint):
class ConfluenceConnector(
CheckpointedConnector[ConfluenceCheckpoint],
SlimConnector,
SlimConnectorWithPermSync,
CredentialsConnector,
):
def __init__(
@@ -111,7 +108,6 @@ class ConfluenceConnector(
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
scoped_token: bool = False,
) -> None:
self.wiki_base = wiki_base
self.is_cloud = is_cloud
@@ -122,7 +118,6 @@ class ConfluenceConnector(
self.batch_size = batch_size
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self.scoped_token = scoped_token
self._confluence_client: OnyxConfluence | None = None
self._low_timeout_confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
@@ -200,7 +195,6 @@ class ConfluenceConnector(
is_cloud=self.is_cloud,
url=self.wiki_base,
credentials_provider=credentials_provider,
scoped_token=self.scoped_token,
)
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
@@ -213,7 +207,6 @@ class ConfluenceConnector(
url=self.wiki_base,
credentials_provider=credentials_provider,
timeout=3,
scoped_token=self.scoped_token,
)
low_timeout_confluence_client._probe_connection(**self.probe_kwargs)
low_timeout_confluence_client._initialize_connection(**self.final_kwargs)
@@ -565,21 +558,7 @@ class ConfluenceConnector(
def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint:
return ConfluenceCheckpoint.model_validate_json(checkpoint_json)
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
return self._retrieve_all_slim_docs(
start=start,
end=end,
callback=callback,
include_permissions=False,
)
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -589,28 +568,12 @@ class ConfluenceConnector(
Return 'slim' docs (IDs + minimal permission data).
Does not fetch actual text. Used primarily for incremental permission sync.
"""
return self._retrieve_all_slim_docs(
start=start,
end=end,
callback=callback,
include_permissions=True,
)
def _retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
include_permissions: bool = True,
) -> GenerateSlimDocumentOutput:
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
space_level_access_info: dict[str, ExternalAccess] = {}
if include_permissions:
space_level_access_info = get_all_space_permissions(
self.confluence_client, self.is_cloud
)
space_level_access_info = get_all_space_permissions(
self.confluence_client, self.is_cloud
)
def get_external_access(
doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]]
@@ -637,10 +600,8 @@ class ConfluenceConnector(
doc_metadata_list.append(
SlimDocument(
id=page_id,
external_access=(
get_external_access(page_id, page_restrictions, page_ancestors)
if include_permissions
else None
external_access=get_external_access(
page_id, page_restrictions, page_ancestors
),
)
)
@@ -675,12 +636,8 @@ class ConfluenceConnector(
doc_metadata_list.append(
SlimDocument(
id=attachment_id,
external_access=(
get_external_access(
attachment_id, attachment_restrictions, []
)
if include_permissions
else None
external_access=get_external_access(
attachment_id, attachment_restrictions, []
),
)
)
@@ -691,10 +648,10 @@ class ConfluenceConnector(
if callback and callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
"retrieve_all_slim_documents: Stop signal detected"
)
if callback:
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
@@ -719,14 +676,6 @@ class ConfluenceConnector(
f"Unexpected error while validating Confluence settings: {e}"
)
if self.space:
try:
self.low_timeout_confluence_client.get_space(self.space)
except ApiError as e:
raise ConnectorValidationError(
"Invalid Confluence space key provided"
) from e
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
@@ -775,7 +724,7 @@ if __name__ == "__main__":
end = datetime.now().timestamp()
# Fetch all `SlimDocuments`.
for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync():
for slim_doc in confluence_connector.retrieve_all_slim_documents():
print(slim_doc)
# Fetch all `Documents`.
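
A condensed sketch of the pattern on the left side of this diff, where one private generator serves both the plain and permission-aware entry points; names and return types are simplified stand-ins for the real connector:

from collections.abc import Iterator

def fetch_space_permissions() -> dict[str, list[str]]:
    # Placeholder for get_all_space_permissions in the real connector.
    return {}

def _retrieve_all_slim(ids: list[str], include_permissions: bool) -> Iterator[dict]:
    # Space-level permissions are only fetched when the caller needs them.
    space_permissions = fetch_space_permissions() if include_permissions else {}
    for doc_id in ids:
        yield {
            "id": doc_id,
            "external_access": space_permissions.get(doc_id) if include_permissions else None,
        }

def retrieve_all_slim_docs(ids: list[str]) -> Iterator[dict]:
    return _retrieve_all_slim(ids, include_permissions=False)

def retrieve_all_slim_docs_perm_sync(ids: list[str]) -> Iterator[dict]:
    return _retrieve_all_slim(ids, include_permissions=True)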

View File

@@ -41,7 +41,6 @@ from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
@@ -88,20 +87,16 @@ class OnyxConfluence:
url: str,
credentials_provider: CredentialsProviderInterface,
timeout: int | None = None,
scoped_token: bool = False,
# should generally not be passed in, but making it overridable for
# easier testing
confluence_user_profiles_override: list[dict[str, str]] | None = (
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
),
) -> None:
self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1])
url = scoped_url(url, "confluence") if scoped_token else url
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.scoped_token = scoped_token
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
@@ -223,34 +218,6 @@ class OnyxConfluence:
with self._credentials_provider:
credentials, _ = self._renew_credentials()
if self.scoped_token:
# v2 endpoint doesn't always work with scoped tokens, use v1
token = credentials["confluence_access_token"]
probe_url = f"{self.base_url}/rest/api/space?limit=1"
import requests
logger.info(f"First and Last 5 of token: {token[:5]}...{token[-5:]}")
try:
r = requests.get(
probe_url,
headers={"Authorization": f"Bearer {token}"},
timeout=10,
)
r.raise_for_status()
except HTTPError as e:
if e.response.status_code == 403:
logger.warning(
"scoped token authenticated but not valid for probe endpoint (spaces)"
)
else:
if "WWW-Authenticate" in e.response.headers:
logger.warning(
f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}"
)
logger.warning(f"Full error: {e.response.text}")
raise e
return
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
@@ -269,7 +236,6 @@ class OnyxConfluence:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
logger.info("running with cloud client")
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
@@ -338,9 +304,7 @@ class OnyxConfluence:
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info(
f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}"
)
logger.info("Connecting to Confluence with Personal Access Token.")
if self._is_cloud:
confluence = Confluence(
url=self._url,

View File

@@ -5,10 +5,7 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TypeVar
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
@@ -151,17 +148,3 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
def is_atlassian_date_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)
def get_cloudId(base_url: str) -> str:
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
response = requests.get(tenant_info_url, timeout=10)
response.raise_for_status()
return response.json()["cloudId"]
def scoped_url(url: str, product: str) -> str:
parsed = urlparse(url)
base_url = parsed.scheme + "://" + parsed.netloc
cloud_id = get_cloudId(base_url)
return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}"

View File

@@ -1,4 +1,3 @@
import importlib
from typing import Any
from typing import Type
@@ -7,16 +6,60 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.bitbucket.connector import BitbucketConnector
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
from onyx.connectors.gitbook.connector import GitbookConnector
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.gitlab.connector import GitlabConnector
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.jira.connector import JiraConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.registry import CONNECTOR_CLASS_MAP
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.outline.connector import OutlineConnector
from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
@@ -29,75 +72,101 @@ class ConnectorMissingException(Exception):
pass
# Cache for already imported connector classes
_connector_cache: dict[DocumentSource, Type[BaseConnector]] = {}
def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]:
"""Dynamically load and cache a connector class."""
if source in _connector_cache:
return _connector_cache[source]
if source not in CONNECTOR_CLASS_MAP:
raise ConnectorMissingException(f"Connector not found for source={source}")
mapping = CONNECTOR_CLASS_MAP[source]
try:
module = importlib.import_module(mapping.module_path)
connector_class = getattr(module, mapping.class_name)
_connector_cache[source] = connector_class
return connector_class
except (ImportError, AttributeError) as e:
raise ConnectorMissingException(
f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}"
)
def _validate_connector_supports_input_type(
connector: Type[BaseConnector],
input_type: InputType | None,
source: DocumentSource,
) -> None:
"""Validate that a connector supports the requested input type."""
if input_type is None:
return
# Check each input type requirement separately for clarity
load_state_unsupported = input_type == InputType.LOAD_STATE and not issubclass(
connector, LoadConnector
)
poll_unsupported = (
input_type == InputType.POLL
# Either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointedConnector)
)
)
event_unsupported = input_type == InputType.EVENT and not issubclass(
connector, EventConnector
)
if any([load_state_unsupported, poll_unsupported, event_unsupported]):
raise ConnectorMissingException(
f"Connector for source={source} does not accept input_type={input_type}"
)
def identify_connector_class(
source: DocumentSource,
input_type: InputType | None = None,
) -> Type[BaseConnector]:
# Load the connector class using lazy loading
connector = _load_connector_class(source)
connector_map = {
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.POLL: SlackConnector,
InputType.SLIM_RETRIEVAL: SlackConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
DocumentSource.GITLAB: GitlabConnector,
DocumentSource.GITBOOK: GitbookConnector,
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
DocumentSource.BOOKSTACK: BookstackConnector,
DocumentSource.OUTLINE: OutlineConnector,
DocumentSource.CONFLUENCE: ConfluenceConnector,
DocumentSource.JIRA: JiraConnector,
DocumentSource.PRODUCTBOARD: ProductboardConnector,
DocumentSource.SLAB: SlabConnector,
DocumentSource.NOTION: NotionConnector,
DocumentSource.ZULIP: ZulipConnector,
DocumentSource.GURU: GuruConnector,
DocumentSource.LINEAR: LinearConnector,
DocumentSource.HUBSPOT: HubSpotConnector,
DocumentSource.DOCUMENT360: Document360Connector,
DocumentSource.GONG: GongConnector,
DocumentSource.GOOGLE_SITES: GoogleSitesConnector,
DocumentSource.ZENDESK: ZendeskConnector,
DocumentSource.LOOPIO: LoopioConnector,
DocumentSource.DROPBOX: DropboxConnector,
DocumentSource.SHAREPOINT: SharepointConnector,
DocumentSource.TEAMS: TeamsConnector,
DocumentSource.SALESFORCE: SalesforceConnector,
DocumentSource.DISCOURSE: DiscourseConnector,
DocumentSource.AXERO: AxeroConnector,
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.ASANA: AsanaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.DISCORD: DiscordConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
DocumentSource.HIGHSPOT: HighspotConnector,
DocumentSource.IMAP: ImapConnector,
DocumentSource.BITBUCKET: BitbucketConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}
connector_by_source = connector_map.get(source, {})
# Validate connector supports the requested input_type
_validate_connector_supports_input_type(connector, input_type, source)
if isinstance(connector_by_source, dict):
if input_type is None:
# If not specified, default to most exhaustive update
connector = connector_by_source.get(InputType.LOAD_STATE)
else:
connector = connector_by_source.get(input_type)
else:
connector = connector_by_source
if connector is None:
raise ConnectorMissingException(f"Connector not found for source={source}")
if any(
[
(
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector)
),
(
input_type == InputType.POLL
# either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointedConnector)
)
),
(
input_type == InputType.EVENT
and not issubclass(connector, EventConnector)
),
]
):
raise ConnectorMissingException(
f"Connector for source={source} does not accept input_type={input_type}"
)
return connector
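
A minimal sketch of the lazy-import pattern this factory moves to, with a made-up registry entry; the real mapping lives in onyx.connectors.registry.CONNECTOR_CLASS_MAP:

import importlib
from dataclasses import dataclass

@dataclass(frozen=True)
class ConnectorMapping:
    module_path: str
    class_name: str

# Hypothetical registry; the real map covers every DocumentSource.
REGISTRY: dict[str, ConnectorMapping] = {
    "web": ConnectorMapping("onyx.connectors.web.connector", "WebConnector"),
}

_cache: dict[str, type] = {}

def load_connector_class(source: str) -> type:
    """Import the connector class on first use, then serve it from the cache."""
    if source in _cache:
        return _cache[source]
    mapping = REGISTRY[source]
    module = importlib.import_module(mapping.module_path)
    connector_class = getattr(module, mapping.class_name)
    _cache[source] = connector_class
    return connector_class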

View File

@@ -32,7 +32,6 @@ def _create_image_section(
image_data: bytes,
parent_file_name: str,
display_name: str,
media_type: str | None = None,
link: str | None = None,
idx: int = 0,
) -> tuple[ImageSection, str | None]:
@@ -59,9 +58,6 @@ def _create_image_section(
image_data=image_data,
file_id=file_id,
display_name=display_name,
media_type=(
media_type if media_type is not None else "application/octet-stream"
),
link=link,
file_origin=FileOrigin.CONNECTOR,
)
@@ -127,7 +123,6 @@ def _process_file(
image_data=image_data,
parent_file_name=file_id,
display_name=title,
media_type=file_type,
)
return [
@@ -199,7 +194,6 @@ def _process_file(
image_data=img_data,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
media_type="application/octet-stream", # Default media type for embedded images
idx=idx,
)
sections.append(image_section)

View File

@@ -219,19 +219,12 @@ def _get_batch_rate_limited(
def _get_userinfo(user: NamedUser) -> dict[str, str]:
def _safe_get(attr_name: str) -> str | None:
try:
return cast(str | None, getattr(user, attr_name))
except GithubException:
logger.debug(f"Error getting {attr_name} for user")
return None
return {
k: v
for k, v in {
"login": _safe_get("login"),
"name": _safe_get("name"),
"email": _safe_get("email"),
"login": user.login,
"name": user.name,
"email": user.email,
}.items()
if v is not None
}
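
The left side of this hunk uses a "safe getattr, drop Nones" pattern; a generic sketch, with a broad except standing in for GithubException:

def safe_user_info(user: object) -> dict[str, str]:
    """Collect the attributes that exist and are non-None, tolerating per-field API errors."""
    def _safe_get(attr_name: str) -> str | None:
        try:
            return getattr(user, attr_name, None)
        except Exception:  # the real code catches GithubException specifically
            return None

    return {
        k: v
        for k, v in {
            "login": _safe_get("login"),
            "name": _safe_get("name"),
            "email": _safe_get("email"),
        }.items()
        if v is not None
    }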

View File

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
@@ -61,19 +61,6 @@ EMAIL_FIELDS = [
add_retries = retry_builder(tries=50, max_delay=30)
def _is_mail_service_disabled_error(error: HttpError) -> bool:
"""Detect if the Gmail API is telling us the mailbox is not provisioned."""
if error.resp.status != 400:
return False
error_message = str(error)
return (
"Mail service not enabled" in error_message
or "failedPrecondition" in error_message
)
def _build_time_range_query(
time_range_start: SecondsSinceUnixEpoch | None = None,
time_range_end: SecondsSinceUnixEpoch | None = None,
@@ -232,7 +219,7 @@ def thread_to_document(
)
class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
class GmailConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
@@ -320,42 +307,33 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
doc_batch = []
for user_email in self._get_all_user_emails():
gmail_service = get_gmail_service(self.creds, user_email)
try:
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
full_threads = execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
):
full_threads = execute_single_retrieval(
retrieval_function=gmail_service.users().threads().get,
list_key=None,
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
except HttpError as e:
if _is_mail_service_disabled_error(e):
logger.warning(
"Skipping Gmail sync for %s because the mailbox is disabled.",
user_email,
)
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
full_thread = list(full_threads)[0]
doc = thread_to_document(full_thread, user_email)
if doc is None:
continue
raise
doc_batch.append(doc)
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -371,44 +349,35 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
for user_email in self._get_all_user_emails():
logger.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
try:
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
external_access=ExternalAccess(
external_user_emails={user_email},
external_user_group_ids=set(),
is_public=False,
),
)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,
list_key="threads",
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(
id=thread["id"],
external_access=ExternalAccess(
external_user_emails={user_email},
external_user_group_ids=set(),
is_public=False,
),
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
)
if len(doc_batch) > SLIM_BATCH_SIZE:
yield doc_batch
doc_batch = []
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
)
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
except HttpError as e:
if _is_mail_service_disabled_error(e):
logger.warning(
"Skipping slim Gmail sync for %s because the mailbox is disabled.",
user_email,
)
continue
raise
callback.progress("retrieve_all_slim_documents", 1)
if doc_batch:
yield doc_batch
@@ -431,7 +400,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
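
A simplified sketch of the per-user guard shown in this file's hunks: rather than failing the whole sync, users whose mailbox is not provisioned are skipped. fetch_threads is a hypothetical callable standing in for the Gmail pagination above:

def is_mailbox_disabled_error(error: Exception) -> bool:
    # Mirrors the check in the diff: a 400 whose message says mail is not enabled.
    status = getattr(getattr(error, "resp", None), "status", None)
    if status != 400:
        return False
    message = str(error)
    return "Mail service not enabled" in message or "failedPrecondition" in message

def index_all_users(user_emails, fetch_threads):
    for user_email in user_emails:
        try:
            yield from fetch_threads(user_email)
        except Exception as exc:  # the real code catches HttpError
            if is_mailbox_disabled_error(exc):
                continue  # provisioned user without Gmail; nothing to index
            raise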

View File

@@ -64,7 +64,7 @@ from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -153,7 +153,7 @@ class DriveIdStatus(Enum):
class GoogleDriveConnector(
SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
SlimConnector, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint]
):
def __init__(
self,
@@ -1137,9 +1137,7 @@ class GoogleDriveConnector(
convert_func,
(
[file.user_email, self.primary_admin_email]
+ get_file_owners(
file.drive_file, self.primary_admin_email
),
+ get_file_owners(file.drive_file),
file.drive_file,
),
)
@@ -1296,7 +1294,7 @@ class GoogleDriveConnector(
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -97,15 +97,14 @@ def _execute_with_retry(request: Any) -> Any:
raise Exception(f"Failed to execute request after {max_attempts} attempts")
def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]:
def get_file_owners(file: GoogleDriveFileType) -> list[str]:
"""
Get the owners of a file if the attribute is present.
"""
return [
email
owner.get("emailAddress")
for owner in file.get("owners", [])
if (email := owner.get("emailAddress"))
and email.split("@")[-1] == primary_admin_email.split("@")[-1]
if owner.get("emailAddress")
]
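
The two variants in this hunk differ in whether owners outside the admin's domain are kept. A runnable sketch of the domain-filtered version with sample data:

def file_owners_in_domain(file: dict, primary_admin_email: str) -> list[str]:
    """Return owner emails that share the primary admin's domain."""
    admin_domain = primary_admin_email.split("@")[-1]
    return [
        email
        for owner in file.get("owners", [])
        if (email := owner.get("emailAddress")) and email.split("@")[-1] == admin_domain
    ]

print(file_owners_in_domain(
    {"owners": [{"emailAddress": "a@acme.com"}, {"emailAddress": "b@other.com"}]},
    "admin@acme.com",
))  # -> ['a@acme.com']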

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -38,7 +38,7 @@ class HighspotSpot(BaseModel):
name: str
class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
"""
Connector for loading data from Highspot.
@@ -362,7 +362,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
description = item_details.get("description", "")
return title, description
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -1,18 +1,14 @@
import re
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
import requests
from hubspot import HubSpot # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -29,10 +25,6 @@ HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
# Available HubSpot object types
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
HUBSPOT_PAGE_SIZE = 100
T = TypeVar("T")
logger = setup_logger()
@@ -46,7 +38,6 @@ class HubSpotConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self._access_token = access_token
self._portal_id: str | None = None
self._rate_limiter = HubSpotRateLimiter()
# Set object types to fetch, default to all available types
if object_types is None:
@@ -86,37 +77,6 @@ class HubSpotConnector(LoadConnector, PollConnector):
"""Set the portal ID."""
self._portal_id = value
def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
return self._rate_limiter.call(func, *args, **kwargs)
def _paginated_results(
self,
fetch_page: Callable[..., Any],
**kwargs: Any,
) -> Generator[Any, None, None]:
base_kwargs = dict(kwargs)
base_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE)
after: str | None = None
while True:
page_kwargs = base_kwargs.copy()
if after is not None:
page_kwargs["after"] = after
page = self._call_hubspot(fetch_page, **page_kwargs)
results = getattr(page, "results", [])
for result in results:
yield result
paging = getattr(page, "paging", None)
next_page = getattr(paging, "next", None) if paging else None
if next_page is None:
break
after = getattr(next_page, "after", None)
if after is None:
break
def _clean_html_content(self, html_content: str) -> str:
"""Clean HTML content and extract raw text"""
if not html_content:
@@ -190,82 +150,78 @@ class HubSpotConnector(LoadConnector, PollConnector):
) -> list[dict[str, Any]]:
"""Get associated objects for a given object"""
try:
associations_iter = self._paginated_results(
api_client.crm.associations.v4.basic_api.get_page,
associations = api_client.crm.associations.v4.basic_api.get_page(
object_type=from_object_type,
object_id=object_id,
to_object_type=to_object_type,
)
object_ids = [assoc.to_object_id for assoc in associations_iter]
associated_objects = []
if associations.results:
object_ids = [assoc.to_object_id for assoc in associations.results]
associated_objects: list[dict[str, Any]] = []
# Batch get the associated objects
if to_object_type == "contacts":
for obj_id in object_ids:
try:
obj = api_client.crm.contacts.basic_api.get_by_id(
contact_id=obj_id,
properties=[
"firstname",
"lastname",
"email",
"company",
"jobtitle",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
if to_object_type == "contacts":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.contacts.basic_api.get_by_id,
contact_id=obj_id,
properties=[
"firstname",
"lastname",
"email",
"company",
"jobtitle",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch contact {obj_id}: {e}")
elif to_object_type == "companies":
for obj_id in object_ids:
try:
obj = api_client.crm.companies.basic_api.get_by_id(
company_id=obj_id,
properties=[
"name",
"domain",
"industry",
"city",
"state",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch company {obj_id}: {e}")
elif to_object_type == "companies":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.companies.basic_api.get_by_id,
company_id=obj_id,
properties=[
"name",
"domain",
"industry",
"city",
"state",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch company {obj_id}: {e}")
elif to_object_type == "deals":
for obj_id in object_ids:
try:
obj = api_client.crm.deals.basic_api.get_by_id(
deal_id=obj_id,
properties=[
"dealname",
"amount",
"dealstage",
"closedate",
"pipeline",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
elif to_object_type == "deals":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.deals.basic_api.get_by_id,
deal_id=obj_id,
properties=[
"dealname",
"amount",
"dealstage",
"closedate",
"pipeline",
],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch deal {obj_id}: {e}")
elif to_object_type == "tickets":
for obj_id in object_ids:
try:
obj = self._call_hubspot(
api_client.crm.tickets.basic_api.get_by_id,
ticket_id=obj_id,
properties=["subject", "content", "hs_ticket_priority"],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
elif to_object_type == "tickets":
for obj_id in object_ids:
try:
obj = api_client.crm.tickets.basic_api.get_by_id(
ticket_id=obj_id,
properties=["subject", "content", "hs_ticket_priority"],
)
associated_objects.append(obj.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch ticket {obj_id}: {e}")
return associated_objects
@@ -283,33 +239,33 @@ class HubSpotConnector(LoadConnector, PollConnector):
) -> list[dict[str, Any]]:
"""Get notes associated with a given object"""
try:
associations_iter = self._paginated_results(
api_client.crm.associations.v4.basic_api.get_page,
# Get associations to notes (engagement type)
associations = api_client.crm.associations.v4.basic_api.get_page(
object_type=object_type,
object_id=object_id,
to_object_type="notes",
)
note_ids = [assoc.to_object_id for assoc in associations_iter]
associated_notes = []
if associations.results:
note_ids = [assoc.to_object_id for assoc in associations.results]
for note_id in note_ids:
try:
# Notes are engagements in HubSpot, use the engagements API
note = self._call_hubspot(
api_client.crm.objects.notes.basic_api.get_by_id,
note_id=note_id,
properties=[
"hs_note_body",
"hs_timestamp",
"hs_created_by",
"hubspot_owner_id",
],
)
associated_notes.append(note.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch note {note_id}: {e}")
# Batch get the associated notes
for note_id in note_ids:
try:
# Notes are engagements in HubSpot, use the engagements API
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note_id,
properties=[
"hs_note_body",
"hs_timestamp",
"hs_created_by",
"hubspot_owner_id",
],
)
associated_notes.append(note.to_dict())
except Exception as e:
logger.warning(f"Failed to fetch note {note_id}: {e}")
return associated_notes
@@ -402,9 +358,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
tickets_iter = self._paginated_results(
api_client.crm.tickets.basic_api.get_page,
all_tickets = api_client.crm.tickets.get_all(
properties=[
"subject",
"content",
@@ -417,7 +371,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for ticket in tickets_iter:
for ticket in all_tickets:
updated_at = ticket.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -505,9 +459,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
companies_iter = self._paginated_results(
api_client.crm.companies.basic_api.get_page,
all_companies = api_client.crm.companies.get_all(
properties=[
"name",
"domain",
@@ -523,7 +475,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for company in companies_iter:
for company in all_companies:
updated_at = company.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -630,9 +582,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
deals_iter = self._paginated_results(
api_client.crm.deals.basic_api.get_page,
all_deals = api_client.crm.deals.get_all(
properties=[
"dealname",
"amount",
@@ -648,7 +598,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for deal in deals_iter:
for deal in all_deals:
updated_at = deal.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
@@ -753,9 +703,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
api_client = HubSpot(access_token=self.access_token)
contacts_iter = self._paginated_results(
api_client.crm.contacts.basic_api.get_page,
all_contacts = api_client.crm.contacts.get_all(
properties=[
"firstname",
"lastname",
@@ -773,7 +721,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch: list[Document] = []
for contact in contacts_iter:
for contact in all_contacts:
updated_at = contact.updated_at.replace(tzinfo=None)
if start is not None and updated_at < start.replace(tzinfo=None):
continue
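
A standalone sketch of the cursor pagination used by the _paginated_results helper in this diff: follow paging.next.after until it runs out. fetch_page stands in for any HubSpot basic_api.get_page method:

from collections.abc import Iterator
from typing import Any, Callable

def paginate(fetch_page: Callable[..., Any], **kwargs: Any) -> Iterator[Any]:
    after: str | None = None
    while True:
        page_kwargs = dict(kwargs)
        if after is not None:
            page_kwargs["after"] = after
        page = fetch_page(**page_kwargs)
        for result in getattr(page, "results", None) or []:
            yield result
        paging = getattr(page, "paging", None)
        next_page = getattr(paging, "next", None) if paging else None
        after = getattr(next_page, "after", None) if next_page else None
        if after is None:
            break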

View File

@@ -1,145 +0,0 @@
from __future__ import annotations
import time
from collections.abc import Callable
from typing import Any
from typing import TypeVar
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
RateLimitTriedTooManyTimesError,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
T = TypeVar("T")
# HubSpot exposes a ten second rolling window (x-hubspot-ratelimit-interval-milliseconds)
# with a maximum of 190 requests, and a per-second limit of 19 requests.
_HUBSPOT_TEN_SECOND_LIMIT = 190
_HUBSPOT_TEN_SECOND_PERIOD = 10 # seconds
_HUBSPOT_SECONDLY_LIMIT = 19
_HUBSPOT_SECONDLY_PERIOD = 1 # second
_DEFAULT_SLEEP_SECONDS = 10
_SLEEP_PADDING_SECONDS = 1.0
_MAX_RATE_LIMIT_RETRIES = 5
def _extract_header(headers: Any, key: str) -> str | None:
if headers is None:
return None
getter = getattr(headers, "get", None)
if callable(getter):
value = getter(key)
if value is not None:
return value
if isinstance(headers, dict):
value = headers.get(key)
if value is not None:
return value
return None
def is_rate_limit_error(exception: Exception) -> bool:
status = getattr(exception, "status", None)
if status == 429:
return True
headers = getattr(exception, "headers", None)
if headers is not None:
remaining = _extract_header(headers, "x-hubspot-ratelimit-remaining")
if remaining == "0":
return True
secondly_remaining = _extract_header(
headers, "x-hubspot-ratelimit-secondly-remaining"
)
if secondly_remaining == "0":
return True
message = str(exception)
return "RATE_LIMIT" in message or "Too Many Requests" in message
def get_rate_limit_retry_delay_seconds(exception: Exception) -> float:
headers = getattr(exception, "headers", None)
retry_after = _extract_header(headers, "Retry-After")
if retry_after:
try:
return float(retry_after) + _SLEEP_PADDING_SECONDS
except ValueError:
logger.debug(
"Failed to parse Retry-After header '%s' as float", retry_after
)
interval_ms = _extract_header(headers, "x-hubspot-ratelimit-interval-milliseconds")
if interval_ms:
try:
return float(interval_ms) / 1000.0 + _SLEEP_PADDING_SECONDS
except ValueError:
logger.debug(
"Failed to parse x-hubspot-ratelimit-interval-milliseconds '%s' as float",
interval_ms,
)
secondly_limit = _extract_header(headers, "x-hubspot-ratelimit-secondly")
if secondly_limit:
try:
per_second = max(float(secondly_limit), 1.0)
return (1.0 / per_second) + _SLEEP_PADDING_SECONDS
except ValueError:
logger.debug(
"Failed to parse x-hubspot-ratelimit-secondly '%s' as float",
secondly_limit,
)
return _DEFAULT_SLEEP_SECONDS + _SLEEP_PADDING_SECONDS
class HubSpotRateLimiter:
def __init__(
self,
*,
ten_second_limit: int = _HUBSPOT_TEN_SECOND_LIMIT,
ten_second_period: int = _HUBSPOT_TEN_SECOND_PERIOD,
secondly_limit: int = _HUBSPOT_SECONDLY_LIMIT,
secondly_period: int = _HUBSPOT_SECONDLY_PERIOD,
max_retries: int = _MAX_RATE_LIMIT_RETRIES,
) -> None:
self._max_retries = max_retries
@rate_limit_builder(max_calls=secondly_limit, period=secondly_period)
@rate_limit_builder(max_calls=ten_second_limit, period=ten_second_period)
def _execute(callable_: Callable[[], T]) -> T:
return callable_()
self._execute = _execute
def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
attempts = 0
while True:
try:
return self._execute(lambda: func(*args, **kwargs))
except Exception as exc: # pylint: disable=broad-except
if not is_rate_limit_error(exc):
raise
attempts += 1
if attempts > self._max_retries:
raise RateLimitTriedTooManyTimesError(
"Exceeded configured HubSpot rate limit retries"
) from exc
wait_time = get_rate_limit_retry_delay_seconds(exc)
logger.notice(
"HubSpot rate limit reached. Sleeping %.2f seconds before retrying.",
wait_time,
)
time.sleep(wait_time)
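
A compact sketch of the retry behaviour this module implemented, with the header precedence reduced to Retry-After and the interval header; call_with_backoff is a hypothetical wrapper around any HubSpot client method:

import time
from typing import Any, Callable

def retry_delay_seconds(headers: dict[str, str], default: float = 10.0) -> float:
    """Prefer Retry-After, then the HubSpot interval header, then a fixed default."""
    if headers.get("Retry-After"):
        return float(headers["Retry-After"]) + 1.0
    interval_ms = headers.get("x-hubspot-ratelimit-interval-milliseconds")
    if interval_ms:
        return float(interval_ms) / 1000.0 + 1.0
    return default + 1.0

def call_with_backoff(func: Callable[..., Any], *args: Any, max_retries: int = 5, **kwargs: Any) -> Any:
    for attempt in range(max_retries + 1):
        try:
            return func(*args, **kwargs)
        except Exception as exc:  # the real code narrows this to rate-limit errors
            if attempt == max_retries or getattr(exc, "status", None) != 429:
                raise
            time.sleep(retry_delay_seconds(getattr(exc, "headers", None) or {}))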

View File

@@ -97,20 +97,11 @@ class PollConnector(BaseConnector):
raise NotImplementedError
# Slim connectors retrieve just the ids of documents
# Slim connectors can retrieve just the ids and
# permission syncing information for connected documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs(
self,
) -> GenerateSlimDocumentOutput:
raise NotImplementedError
# Slim connectors retrieve both the ids AND
# permission syncing information for connected documents
class SlimConnectorWithPermSync(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
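
For reference, a minimal hypothetical implementation of the slim interface, matching the generator-of-batches contract shown in this hunk (the method name follows the right-hand side of the diff):

from collections.abc import Iterator

class InMemorySlimConnector:
    """Toy connector that only knows document ids, yielded in batches."""

    def __init__(self, ids: list[str], batch_size: int = 100) -> None:
        self.ids = ids
        self.batch_size = batch_size

    def retrieve_all_slim_documents(
        self,
        start: float | None = None,
        end: float | None = None,
        callback: object | None = None,
    ) -> Iterator[list[dict]]:
        batch: list[dict] = []
        for doc_id in self.ids:
            batch.append({"id": doc_id, "external_access": None})
            if len(batch) >= self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch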

View File

@@ -25,11 +25,11 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.jira.access import get_project_permissions
from onyx.connectors.jira.utils import best_effort_basic_expert_info
from onyx.connectors.jira.utils import best_effort_get_field_from_issue
@@ -247,7 +247,7 @@ def _perform_jql_search_v2(
def process_jira_issue(
jira_base_url: str,
jira_client: JIRA,
issue: Issue,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
@@ -281,7 +281,7 @@ def process_jira_issue(
)
return None
page_url = build_jira_url(jira_base_url, issue.key)
page_url = build_jira_url(jira_client, issue.key)
metadata_dict: dict[str, str | list[str]] = {}
people = set()
@@ -359,10 +359,7 @@ class JiraConnectorCheckpoint(ConnectorCheckpoint):
offset: int | None = None
class JiraConnector(
CheckpointedConnectorWithPermSync[JiraConnectorCheckpoint],
SlimConnectorWithPermSync,
):
class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnector):
def __init__(
self,
jira_base_url: str,
@@ -375,23 +372,15 @@ class JiraConnector(
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
# Custom JQL query to filter Jira issues
jql_query: str | None = None,
scoped_token: bool = False,
) -> None:
self.batch_size = batch_size
# dealing with scoped tokens is a bit tricky because we need to hit api.atlassian.net
# when making jira requests but still want correct links to issues in the UI.
# So, the user's base url is stored here, but converted to a scoped url when passed
# to the jira client.
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
self.jira_project = project_key
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
self.jql_query = jql_query
self.scoped_token = scoped_token
self._jira_client: JIRA | None = None
# Cache project permissions to avoid fetching them repeatedly across runs
self._project_permissions_cache: dict[str, Any] = {}
@property
def comment_email_blacklist(self) -> tuple:
@@ -410,26 +399,10 @@ class JiraConnector(
return ""
return f'"{self.jira_project}"'
def _get_project_permissions(self, project_key: str) -> Any:
"""Get project permissions with caching.
Args:
project_key: The Jira project key
Returns:
The external access permissions for the project
"""
if project_key not in self._project_permissions_cache:
self._project_permissions_cache[project_key] = get_project_permissions(
jira_client=self.jira_client, jira_project=project_key
)
return self._project_permissions_cache[project_key]
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
scoped_token=self.scoped_token,
)
return None
@@ -469,37 +442,15 @@ class JiraConnector(
) -> CheckpointOutput[JiraConnectorCheckpoint]:
jql = self._get_jql_query(start, end)
try:
return self._load_from_checkpoint(
jql, checkpoint, include_permissions=False
)
return self._load_from_checkpoint(jql, checkpoint)
except Exception as e:
if is_atlassian_date_error(e):
jql = self._get_jql_query(start - ONE_HOUR, end)
return self._load_from_checkpoint(
jql, checkpoint, include_permissions=False
)
raise e
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraConnectorCheckpoint,
) -> CheckpointOutput[JiraConnectorCheckpoint]:
"""Load documents from checkpoint with permission information included."""
jql = self._get_jql_query(start, end)
try:
return self._load_from_checkpoint(jql, checkpoint, include_permissions=True)
except Exception as e:
if is_atlassian_date_error(e):
jql = self._get_jql_query(start - ONE_HOUR, end)
return self._load_from_checkpoint(
jql, checkpoint, include_permissions=True
)
return self._load_from_checkpoint(jql, checkpoint)
raise e
def _load_from_checkpoint(
self, jql: str, checkpoint: JiraConnectorCheckpoint, include_permissions: bool
self, jql: str, checkpoint: JiraConnectorCheckpoint
) -> CheckpointOutput[JiraConnectorCheckpoint]:
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.offset or 0
@@ -521,25 +472,18 @@ class JiraConnector(
issue_key = issue.key
try:
if document := process_jira_issue(
jira_base_url=self.jira_base,
jira_client=self.jira_client,
issue=issue,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
# Add permission information to the document if requested
if include_permissions:
project_key = get_jira_project_key_from_issue(issue=issue)
if project_key:
document.external_access = self._get_project_permissions(
project_key
)
yield document
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_jira_url(self.jira_base, issue_key),
document_link=build_jira_url(self.jira_client, issue_key),
),
failure_message=f"Failed to process Jira issue: {str(e)}",
exception=e,
@@ -571,7 +515,7 @@ class JiraConnector(
# if we didn't retrieve a full batch, we're done
checkpoint.has_more = current_offset - starting_offset == page_size
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -590,7 +534,6 @@ class JiraConnector(
prev_offset = 0
current_offset = 0
slim_doc_batch = []
while checkpoint.has_more:
for issue in _perform_jql_search(
jira_client=self.jira_client,
@@ -607,12 +550,13 @@ class JiraConnector(
continue
issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY)
id = build_jira_url(self.jira_base, issue_key)
id = build_jira_url(self.jira_client, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
external_access=self._get_project_permissions(project_key),
external_access=get_project_permissions(
jira_client=self.jira_client, jira_project=project_key
),
)
)
current_offset += 1
@@ -757,7 +701,7 @@ if __name__ == "__main__":
start = 0
end = datetime.now().timestamp()
for slim_doc in connector.retrieve_all_slim_docs_perm_sync(
for slim_doc in connector.retrieve_all_slim_documents(
start=start,
end=end,
):

View File

@@ -10,7 +10,6 @@ from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User
from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.logger import setup_logger
@@ -75,18 +74,11 @@ def extract_text_from_adf(adf: dict | None) -> str:
return " ".join(texts)
def build_jira_url(jira_base_url: str, issue_key: str) -> str:
"""
Get the url used to access an issue in the UI.
"""
return f"{jira_base_url}/browse/{issue_key}"
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_client(
credentials: dict[str, Any], jira_base: str, scoped_token: bool = False
) -> JIRA:
jira_base = scoped_url(jira_base, "jira") if scoped_token else jira_base
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
api_token = credentials["jira_api_token"]
# if the user provides an email we assume it's cloud
if "jira_user_email" in credentials:

View File

@@ -1,208 +0,0 @@
"""Registry mapping for connector classes."""
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
class ConnectorMapping(BaseModel):
module_path: str
class_name: str
# Mapping of DocumentSource to connector details for lazy loading
CONNECTOR_CLASS_MAP = {
DocumentSource.WEB: ConnectorMapping(
module_path="onyx.connectors.web.connector",
class_name="WebConnector",
),
DocumentSource.FILE: ConnectorMapping(
module_path="onyx.connectors.file.connector",
class_name="LocalFileConnector",
),
DocumentSource.SLACK: ConnectorMapping(
module_path="onyx.connectors.slack.connector",
class_name="SlackConnector",
),
DocumentSource.GITHUB: ConnectorMapping(
module_path="onyx.connectors.github.connector",
class_name="GithubConnector",
),
DocumentSource.GMAIL: ConnectorMapping(
module_path="onyx.connectors.gmail.connector",
class_name="GmailConnector",
),
DocumentSource.GITLAB: ConnectorMapping(
module_path="onyx.connectors.gitlab.connector",
class_name="GitlabConnector",
),
DocumentSource.GITBOOK: ConnectorMapping(
module_path="onyx.connectors.gitbook.connector",
class_name="GitbookConnector",
),
DocumentSource.GOOGLE_DRIVE: ConnectorMapping(
module_path="onyx.connectors.google_drive.connector",
class_name="GoogleDriveConnector",
),
DocumentSource.BOOKSTACK: ConnectorMapping(
module_path="onyx.connectors.bookstack.connector",
class_name="BookstackConnector",
),
DocumentSource.OUTLINE: ConnectorMapping(
module_path="onyx.connectors.outline.connector",
class_name="OutlineConnector",
),
DocumentSource.CONFLUENCE: ConnectorMapping(
module_path="onyx.connectors.confluence.connector",
class_name="ConfluenceConnector",
),
DocumentSource.JIRA: ConnectorMapping(
module_path="onyx.connectors.jira.connector",
class_name="JiraConnector",
),
DocumentSource.PRODUCTBOARD: ConnectorMapping(
module_path="onyx.connectors.productboard.connector",
class_name="ProductboardConnector",
),
DocumentSource.SLAB: ConnectorMapping(
module_path="onyx.connectors.slab.connector",
class_name="SlabConnector",
),
DocumentSource.NOTION: ConnectorMapping(
module_path="onyx.connectors.notion.connector",
class_name="NotionConnector",
),
DocumentSource.ZULIP: ConnectorMapping(
module_path="onyx.connectors.zulip.connector",
class_name="ZulipConnector",
),
DocumentSource.GURU: ConnectorMapping(
module_path="onyx.connectors.guru.connector",
class_name="GuruConnector",
),
DocumentSource.LINEAR: ConnectorMapping(
module_path="onyx.connectors.linear.connector",
class_name="LinearConnector",
),
DocumentSource.HUBSPOT: ConnectorMapping(
module_path="onyx.connectors.hubspot.connector",
class_name="HubSpotConnector",
),
DocumentSource.DOCUMENT360: ConnectorMapping(
module_path="onyx.connectors.document360.connector",
class_name="Document360Connector",
),
DocumentSource.GONG: ConnectorMapping(
module_path="onyx.connectors.gong.connector",
class_name="GongConnector",
),
DocumentSource.GOOGLE_SITES: ConnectorMapping(
module_path="onyx.connectors.google_site.connector",
class_name="GoogleSitesConnector",
),
DocumentSource.ZENDESK: ConnectorMapping(
module_path="onyx.connectors.zendesk.connector",
class_name="ZendeskConnector",
),
DocumentSource.LOOPIO: ConnectorMapping(
module_path="onyx.connectors.loopio.connector",
class_name="LoopioConnector",
),
DocumentSource.DROPBOX: ConnectorMapping(
module_path="onyx.connectors.dropbox.connector",
class_name="DropboxConnector",
),
DocumentSource.SHAREPOINT: ConnectorMapping(
module_path="onyx.connectors.sharepoint.connector",
class_name="SharepointConnector",
),
DocumentSource.TEAMS: ConnectorMapping(
module_path="onyx.connectors.teams.connector",
class_name="TeamsConnector",
),
DocumentSource.SALESFORCE: ConnectorMapping(
module_path="onyx.connectors.salesforce.connector",
class_name="SalesforceConnector",
),
DocumentSource.DISCOURSE: ConnectorMapping(
module_path="onyx.connectors.discourse.connector",
class_name="DiscourseConnector",
),
DocumentSource.AXERO: ConnectorMapping(
module_path="onyx.connectors.axero.connector",
class_name="AxeroConnector",
),
DocumentSource.CLICKUP: ConnectorMapping(
module_path="onyx.connectors.clickup.connector",
class_name="ClickupConnector",
),
DocumentSource.MEDIAWIKI: ConnectorMapping(
module_path="onyx.connectors.mediawiki.wiki",
class_name="MediaWikiConnector",
),
DocumentSource.WIKIPEDIA: ConnectorMapping(
module_path="onyx.connectors.wikipedia.connector",
class_name="WikipediaConnector",
),
DocumentSource.ASANA: ConnectorMapping(
module_path="onyx.connectors.asana.connector",
class_name="AsanaConnector",
),
DocumentSource.S3: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.R2: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.GOOGLE_CLOUD_STORAGE: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.OCI_STORAGE: ConnectorMapping(
module_path="onyx.connectors.blob.connector",
class_name="BlobStorageConnector",
),
DocumentSource.XENFORO: ConnectorMapping(
module_path="onyx.connectors.xenforo.connector",
class_name="XenforoConnector",
),
DocumentSource.DISCORD: ConnectorMapping(
module_path="onyx.connectors.discord.connector",
class_name="DiscordConnector",
),
DocumentSource.FRESHDESK: ConnectorMapping(
module_path="onyx.connectors.freshdesk.connector",
class_name="FreshdeskConnector",
),
DocumentSource.FIREFLIES: ConnectorMapping(
module_path="onyx.connectors.fireflies.connector",
class_name="FirefliesConnector",
),
DocumentSource.EGNYTE: ConnectorMapping(
module_path="onyx.connectors.egnyte.connector",
class_name="EgnyteConnector",
),
DocumentSource.AIRTABLE: ConnectorMapping(
module_path="onyx.connectors.airtable.airtable_connector",
class_name="AirtableConnector",
),
DocumentSource.HIGHSPOT: ConnectorMapping(
module_path="onyx.connectors.highspot.connector",
class_name="HighspotConnector",
),
DocumentSource.IMAP: ConnectorMapping(
module_path="onyx.connectors.imap.connector",
class_name="ImapConnector",
),
DocumentSource.BITBUCKET: ConnectorMapping(
module_path="onyx.connectors.bitbucket.connector",
class_name="BitbucketConnector",
),
# just for integration tests
DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
module_path="onyx.connectors.mock_connector.connector",
class_name="MockConnector",
),
}
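The mapping above exists so each connector module is imported only when its DocumentSource is actually requested. A hedged sketch of how such a registry is typically resolved at runtime; the loader function and the string-keyed registry below are illustrative, not the repository's actual loader.

import importlib

from pydantic import BaseModel


class ConnectorMapping(BaseModel):
    module_path: str
    class_name: str


# Same shape as CONNECTOR_CLASS_MAP, keyed by a plain string here to keep the
# sketch self-contained instead of importing DocumentSource.
REGISTRY = {
    "web": ConnectorMapping(
        module_path="onyx.connectors.web.connector",
        class_name="WebConnector",
    ),
}


def load_connector_class(source: str) -> type:
    """Import the connector module lazily and return the mapped class."""
    mapping = REGISTRY[source]
    module = importlib.import_module(mapping.module_path)
    return getattr(module, mapping.class_name)


# Usage (only works where the onyx package is importable):
# connector_cls = load_connector_class("web")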

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -151,7 +151,7 @@ def _validate_custom_query_config(config: dict[str, Any]) -> None:
)
class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
Goal
@@ -1119,7 +1119,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
with tempfile.TemporaryDirectory() as temp_dir:
return self._delta_sync(temp_dir, start, end)
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -41,7 +41,7 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import IndexingHeartbeatInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -57,6 +57,8 @@ from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_a
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
@@ -73,8 +75,7 @@ class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
Args:
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests
or https://danswerai.sharepoint.com/teams/team-name)
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
If None, all drives will be accessed.
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
@@ -673,7 +674,7 @@ def _convert_sitepage_to_slim_document(
class SharepointConnector(
SlimConnectorWithPermSync,
SlimConnector,
CheckpointedConnectorWithPermSync[SharepointConnectorCheckpoint],
):
def __init__(
@@ -704,11 +705,9 @@ class SharepointConnector(
# Ensure sites are sharepoint urls
for site_url in self.sites:
if not site_url.startswith("https://") or not (
"/sites/" in site_url or "/teams/" in site_url
):
if not site_url.startswith("https://") or "/sites/" not in site_url:
raise ConnectorValidationError(
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
"Site URLs must be full Sharepoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site)"
)
@property
@@ -723,17 +722,10 @@ class SharepointConnector(
site_data_list = []
for url in site_urls:
parts = url.strip().split("/")
site_type_index = None
if "sites" in parts:
site_type_index = parts.index("sites")
elif "teams" in parts:
site_type_index = parts.index("teams")
if site_type_index is not None:
# Extract the base site URL (up to and including the site/team name)
site_url = "/".join(parts[: site_type_index + 2])
remaining_parts = parts[site_type_index + 2 :]
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
remaining_parts = parts[sites_index + 2 :]
# Extract drive name and folder path
if remaining_parts:
@@ -755,9 +747,7 @@ class SharepointConnector(
)
)
else:
logger.warning(
f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/ or /teams/)"
)
logger.warning(f"Site URL '{url}' is not a valid Sharepoint URL")
return site_data_list
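To make the URL slicing above concrete, a small worked example of splitting a full SharePoint URL into its site URL, drive name, and folder path; this standalone helper mirrors the index-and-join logic in the hunk (one side of the diff also accepts /teams/ URLs) and is illustrative only.

def split_sharepoint_url(url: str) -> tuple[str, str | None, str | None]:
    """Split a SharePoint URL into (site_url, drive_name, folder_path)."""
    parts = url.strip().split("/")
    sites_index = parts.index("sites")
    site_url = "/".join(parts[: sites_index + 2])
    remaining = parts[sites_index + 2 :]
    drive_name = remaining[0] if remaining else None
    folder_path = "/".join(remaining[1:]) if len(remaining) > 1 else None
    return site_url, drive_name, folder_path


print(
    split_sharepoint_url(
        "https://danswerai.sharepoint.com/sites/sharepoint-tests/Shared Documents/test/nested"
    )
)
# ('https://danswerai.sharepoint.com/sites/sharepoint-tests', 'Shared Documents', 'test/nested')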
def _get_drive_items_for_drive_name(
@@ -1453,6 +1443,12 @@ class SharepointConnector(
)
for driveitem in driveitems:
driveitem_extension = get_file_ext(driveitem.name)
if not is_accepted_file_ext(driveitem_extension, OnyxExtensionType.All):
logger.warning(
f"Skipping {driveitem.web_url} as it is not a supported file type"
)
continue
# Only yield empty documents if they are PDFs or images
should_yield_if_empty = (
driveitem_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS
@@ -1476,6 +1472,10 @@ class SharepointConnector(
TextSection(link=driveitem.web_url, text="")
]
yield doc
else:
logger.warning(
f"Skipping {driveitem.web_url} as it is empty and not a PDF or image"
)
except Exception as e:
logger.warning(
f"Failed to process driveitem {driveitem.web_url}: {e}"
@@ -1609,7 +1609,7 @@ class SharepointConnector(
) -> SharepointConnectorCheckpoint:
return SharepointConnectorCheckpoint.model_validate_json(checkpoint_json)
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -164,7 +164,7 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
base_url: str,
@@ -239,7 +239,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -42,7 +42,7 @@ from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -581,7 +581,7 @@ def _process_message(
class SlackConnector(
SlimConnectorWithPermSync,
SlimConnector,
CredentialsConnector,
CheckpointedConnectorWithPermSync[SlackCheckpoint],
):
@@ -732,7 +732,7 @@ class SlackConnector(
self.text_cleaner = SlackTextCleaner(client=self.client)
self.credentials_provider = credentials_provider
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -22,7 +22,7 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -51,7 +51,7 @@ class TeamsCheckpoint(ConnectorCheckpoint):
class TeamsConnector(
CheckpointedConnector[TeamsCheckpoint],
SlimConnectorWithPermSync,
SlimConnector,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
@@ -228,9 +228,9 @@ class TeamsConnector(
has_more=bool(todos),
)
# impls for SlimConnectorWithPermSync
# impls for SlimConnector
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
@@ -572,7 +572,7 @@ if __name__ == "__main__":
)
teams_connector.validate_connector_settings()
for slim_doc in teams_connector.retrieve_all_slim_docs_perm_sync():
for slim_doc in teams_connector.retrieve_all_slim_documents():
...
for doc in load_everything_from_checkpoint_connector(

View File

@@ -219,25 +219,6 @@ def is_valid_url(url: str) -> bool:
return False
def _same_site(base_url: str, candidate_url: str) -> bool:
base, candidate = urlparse(base_url), urlparse(candidate_url)
base_netloc = base.netloc.lower().removeprefix("www.")
candidate_netloc = candidate.netloc.lower().removeprefix("www.")
if base_netloc != candidate_netloc:
return False
base_path = (base.path or "/").rstrip("/")
if base_path in ("", "/"):
return True
candidate_path = candidate.path or "/"
if candidate_path == base_path:
return True
boundary = f"{base_path}/"
return candidate_path.startswith(boundary)
def get_internal_links(
base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True
) -> set[str]:
@@ -258,7 +239,7 @@ def get_internal_links(
# Relative path handling
href = urljoin(url, href)
if _same_site(base_url, href):
if urlparse(href).netloc == urlparse(url).netloc and base_url in href:
internal_links.add(href)
return internal_links
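To illustrate what the boundary-aware check in this hunk accepts and rejects, here is a small standalone example that mirrors the _same_site helper shown above:

from urllib.parse import urlparse


def same_site(base_url: str, candidate_url: str) -> bool:
    """True when candidate_url shares base_url's host and stays under its path prefix."""
    base, candidate = urlparse(base_url), urlparse(candidate_url)
    base_netloc = base.netloc.lower().removeprefix("www.")
    candidate_netloc = candidate.netloc.lower().removeprefix("www.")
    if base_netloc != candidate_netloc:
        return False
    base_path = (base.path or "/").rstrip("/")
    if base_path in ("", "/"):
        return True
    candidate_path = candidate.path or "/"
    return candidate_path == base_path or candidate_path.startswith(f"{base_path}/")


print(same_site("https://example.com/docs", "https://www.example.com/docs/intro"))  # True
print(same_site("https://example.com/docs", "https://example.com/docsearch"))  # False; a plain substring test would accept it
print(same_site("https://example.com/docs", "https://other.com/docs/intro"))  # False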

View File

@@ -26,7 +26,7 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
@@ -376,7 +376,7 @@ class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
class ZendeskConnector(
SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint]
SlimConnector, CheckpointedConnector[ZendeskConnectorCheckpoint]
):
def __init__(
self,
@@ -565,7 +565,7 @@ class ZendeskConnector(
)
return checkpoint
def retrieve_all_slim_docs_perm_sync(
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,

View File

@@ -1,7 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -121,8 +119,8 @@ class BaseFilters(BaseModel):
class UserFileFilters(BaseModel):
user_file_ids: list[UUID] | None = None
project_id: int | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
class IndexFilters(BaseFilters, UserFileFilters):
@@ -357,44 +355,6 @@ class SearchDoc(BaseModel):
secondary_owners: list[str] | None = None
is_internet: bool = False
@classmethod
def from_chunks_or_sections(
cls,
items: "Sequence[InferenceChunk | InferenceSection] | None",
) -> list["SearchDoc"]:
"""Convert a sequence of InferenceChunk or InferenceSection objects to SearchDoc objects."""
if not items:
return []
search_docs = [
cls(
document_id=(
chunk := (
item.center_chunk
if isinstance(item, InferenceSection)
else item
)
).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for item in items
]
return search_docs
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(*args, **kwargs) # type: ignore
initial_dict["updated_at"] = (

View File

@@ -166,6 +166,9 @@ def retrieval_preprocessing(
)
user_file_filters = search_request.user_file_filters
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
user_folder_ids = (
(user_file_filters.user_folder_ids or []) if user_file_filters else []
)
if persona and persona.user_files:
user_file_ids = list(
set(user_file_ids) | set([file.id for file in persona.user_files])
@@ -173,7 +176,7 @@ def retrieval_preprocessing(
final_filters = IndexFilters(
user_file_ids=user_file_ids,
project_id=user_file_filters.project_id if user_file_filters else None,
user_folder_ids=user_folder_ids,
source_type=preset_filters.source_type or predicted_source_filters,
document_set=preset_filters.document_set,
time_cutoff=time_filter or predicted_time_cutoff,

View File

@@ -118,6 +118,40 @@ def inference_section_from_chunks(
)
def chunks_or_sections_to_search_docs(
items: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
if not items:
return []
search_docs = [
SearchDoc(
document_id=(
chunk := (
item.center_chunk if isinstance(item, InferenceSection) else item
)
).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=False,
)
for item in items
]
return search_docs
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
try:
# Re-tokenize using the NLTK tokenizer for better matching

Some files were not shown because too many files have changed in this diff