Compare commits

..

1 Commits

Author SHA1 Message Date
pablonyx
7a9db4753a add some debug logging 2025-03-21 14:50:07 -07:00
450 changed files with 12051 additions and 33646 deletions

View File

@@ -1,209 +0,0 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests-mit:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -9,10 +9,6 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -23,10 +19,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
GONG_ACCESS_KEY_SECRET: ${{ secrets.GONG_ACCESS_KEY_SECRET }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
@@ -53,8 +45,6 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}

View File

@@ -6,419 +6,396 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"compounds": [
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": ["--- Individual ---"],
"presentation": {
"group": "1"
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
"presentation": {
"group": "1"
}
}
{
// Dummy entry used to label the group
"name": "--- Compound ---",
"configurations": [
"--- Individual ---"
],
"presentation": {
"group": "1",
}
},
{
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
}
},
{
"name": "Web / Model / API",
"configurations": [
"Web Server",
"Model Server",
"API Server",
],
"presentation": {
"group": "1",
}
},
{
"name": "Celery (all)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
}
}
],
"configurations": [
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": ["run", "dev"],
"presentation": {
"group": "2"
{
// Dummy entry used to label the group
"name": "--- Individual ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "2",
"order": 0
}
},
{
"name": "Web Server",
"type": "node",
"request": "launch",
"cwd": "${workspaceRoot}/web",
"runtimeExecutable": "npm",
"envFile": "${workspaceFolder}/.vscode/.env",
"runtimeArgs": [
"run", "dev"
],
"presentation": {
"group": "2",
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
"console": "integratedTerminal",
"consoleTitle": "Web Server Console"
},
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
{
"name": "Model Server",
"consoleName": "Model Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"model_server.main:app",
"--reload",
"--port",
"9000"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Model Server Console"
},
"args": ["model_server.main:app", "--reload", "--port", "9000"],
"presentation": {
"group": "2"
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
},
"args": [
"onyx.main:app",
"--reload",
"--port",
"8080"
],
"presentation": {
"group": "2",
},
"consoleTitle": "API Server Console"
},
"consoleTitle": "Model Server Console"
},
{
"name": "API Server",
"consoleName": "API Server",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1"
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2",
},
"consoleTitle": "Slack Bot Console"
},
"args": ["onyx.main:app", "--reload", "--port", "8080"],
"presentation": {
"group": "2"
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery primary Console"
},
"consoleTitle": "API Server Console"
},
// For the listener to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
{
"name": "Slack Bot",
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery light Console"
},
"presentation": {
"group": "2"
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery heavy Console"
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Celery primary",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery indexing Console"
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery"
],
"presentation": {
"group": "2"
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
},
"consoleTitle": "Celery primary Console"
},
{
"name": "Celery light",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery beat Console"
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
"--prefetch-multiplier=8",
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert"
],
"presentation": {
"group": "2"
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2",
},
"consoleTitle": "Pytest Console"
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3",
},
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
],
"presentation": {
"group": "2"
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Pytest Console"
},
{
// Dummy entry used to label the group
"name": "--- Tasks ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "3",
"order": 0
}
},
{
"name": "Clear and Restart External Volumes and Containers",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_containers.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
},
{
// Celery jobs launched through a single background script (legacy)
// Recommend using the "Celery (all)" compound launch instead.
"name": "Background Jobs",
"consoleName": "Background Jobs",
"type": "debugpy",
"request": "launch",
"program": "scripts/dev_run_background_jobs.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Install Python Requirements",
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"-c",
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Debug React Web App in Chrome",
"type": "chrome",
"request": "launch",
"url": "http://localhost:3000",
"webRoot": "${workspaceFolder}/web"
}
]
}
}

View File

@@ -102,7 +102,6 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging

View File

@@ -46,7 +46,6 @@ WORKDIR /app
# Utils used by model server
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -84,7 +84,7 @@ keys = console
keys = generic
[logger_root]
level = INFO
level = WARN
handlers = console
qualname =

View File

@@ -25,9 +25,6 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
# Alembic Config object
config = context.config
@@ -39,7 +36,6 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
@@ -68,7 +64,7 @@ def include_object(
return True
def get_schema_options() -> tuple[str, bool, bool, bool]:
def get_schema_options() -> tuple[str, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -80,10 +76,6 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
# continue on error with individual tenant
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -94,12 +86,14 @@ def get_schema_options() -> tuple[str, bool, bool, bool]:
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
return schema_name, create_schema, upgrade_all_tenants
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -140,12 +134,7 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
) = get_schema_options()
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
@@ -162,15 +151,9 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
)
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
@@ -179,12 +162,7 @@ async def run_async_migrations() -> None:
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
logger.error("--continue is not set, raising exception!")
raise
logger.warning("--continue is set, continuing to next schema.")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
@@ -202,11 +180,7 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
"""This doesn't really get used when we migrate in the cloud."""
logger.info("run_migrations_offline starting.")
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
@@ -256,7 +230,6 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())

View File

@@ -28,20 +28,6 @@ depends_on = None
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""

View File

@@ -1,50 +0,0 @@
"""update prompt length
Revision ID: 4794bc13e484
Revises: f7505c5b0284
Create Date: 2025-04-02 11:26:36.180328
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4794bc13e484"
down_revision = "f7505c5b0284"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)

View File

@@ -1,117 +0,0 @@
"""duplicated no-harm user file migration
Revision ID: 6a804aeb4830
Revises: 8e1ac4f39a9f
Create Date: 2025-04-01 07:26:10.539362
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import datetime
# revision identifiers, used by Alembic.
revision = "6a804aeb4830"
down_revision = "8e1ac4f39a9f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Check if user_file table already exists
conn = op.get_bind()
inspector = inspect(conn)
if not inspector.has_table("user_file"):
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id",
sa.Integer(),
sa.ForeignKey("persona.id"),
primary_key=True,
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
pass

View File

@@ -1,50 +0,0 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 9aadf32dfeb4
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "9aadf32dfeb4"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_settings",
sa.Column(
"enable_contextual_rag",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_name",
sa.String(),
nullable=True,
),
)
op.add_column(
"search_settings",
sa.Column(
"contextual_rag_llm_provider",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("search_settings", "enable_contextual_rag")
op.drop_column("search_settings", "contextual_rag_llm_name")
op.drop_column("search_settings", "contextual_rag_llm_provider")

View File

@@ -1,113 +0,0 @@
"""add user files
Revision ID: 9aadf32dfeb4
Revises: 3781a5eb12cb
Create Date: 2025-01-26 16:08:21.551022
"""
import sqlalchemy as sa
import datetime
from alembic import op
# revision identifiers, used by Alembic.
revision = "9aadf32dfeb4"
down_revision = "3781a5eb12cb"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create user_folder table without parent_id
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("description", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column(
"created_at", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create user_file table with folder_id instead of parent_folder_id
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("link_url", sa.String(), nullable=True),
sa.Column("token_count", sa.Integer(), nullable=True),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
sa.Column(
"cc_pair_id",
sa.Integer(),
sa.ForeignKey("connector_credential_pair.id"),
nullable=True,
unique=True,
),
)
# Create persona__user_file table
op.create_table(
"persona__user_file",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_file_id",
sa.Integer(),
sa.ForeignKey("user_file.id"),
primary_key=True,
),
)
# Create persona__user_folder table
op.create_table(
"persona__user_folder",
sa.Column(
"persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
),
sa.Column(
"user_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
primary_key=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False),
)
# Update existing records to have is_user_file=False instead of NULL
op.execute(
"UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL"
)
def downgrade() -> None:
# Drop the persona__user_folder table
op.drop_table("persona__user_folder")
# Drop the persona__user_file table
op.drop_table("persona__user_file")
# Drop the user_file table
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")
op.drop_column("connector_credential_pair", "is_user_file")

View File

@@ -1,50 +0,0 @@
"""add prompt length limit
Revision ID: f71470ba9274
Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
# revision identifiers, used by Alembic.
revision = "f71470ba9274"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
pass
def downgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
pass

View File

@@ -1,77 +0,0 @@
"""updated constraints for ccpairs
Revision ID: f7505c5b0284
Revises: f71470ba9274
Create Date: 2025-04-01 17:50:42.504818
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f7505c5b0284"
down_revision = "f71470ba9274"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Drop the old foreign-key constraints
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# 2) Re-add them with ondelete='CASCADE'
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="connector",
local_cols=["connector_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="credential",
local_cols=["credential_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Reverse the changes for rollback
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# Recreate without CASCADE
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
"connector",
["connector_id"],
["id"],
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
"credential",
["credential_id"],
["id"],
)

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
is_public=is_public_anywhere,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
)
return access_map

View File

@@ -2,6 +2,7 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -31,6 +32,8 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -25,10 +25,6 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -43,7 +39,6 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -77,13 +72,6 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app

View File

@@ -159,9 +159,6 @@ def _get_space_permissions(
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
logger.info(
f"Found space permissions for space '{space_key}': {space_permissions}"
)
return space_permissions_by_space_key

View File

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
final_chunk_dict[chunk.unique_id] = chunk
chunks_to_keep.append(chunk)
# For each source, filter out the chunks using the permission
# check function for that source
@@ -79,16 +79,6 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list
return chunks_to_keep

View File

@@ -51,14 +51,13 @@ def _get_objects_access_for_user_email_from_salesforce(
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries by the same user are essentially instant
start_time = time.monotonic()
start_time = time.time()
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
end_time = time.monotonic()
end_time = time.time()
logger.info(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
logger.warning(f"User '{user_email}' not found in Salesforce")
return None
# This is the only query that is not cached in the function
@@ -66,7 +65,6 @@ def _get_objects_access_for_user_email_from_salesforce(
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
logger.debug(f"Object ID to access: {object_id_to_access}")
return object_id_to_access

View File

@@ -1,6 +1,10 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger
@@ -24,8 +28,6 @@ def get_any_salesforce_client_for_doc_id(
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
but only one has the permissions to access the permissions needed for the query.
"""
# NOTE: this global seems very very bad
global _ANY_SALESFORCE_CLIENT
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
@@ -40,18 +42,11 @@ def get_any_salesforce_client_for_doc_id(
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
# try emails
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
return None
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
# This contains only the user_ids that we have found in Salesforce.
@@ -82,21 +77,35 @@ def get_salesforce_user_id_from_email(
salesforce database. (Around 0.1-0.3 seconds)
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
"""
# NOTE: this global seems bad
global _CACHED_SF_EMAIL_TO_ID_MAP
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
# some caching via sqlite existed here before ... check history if interested
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
db_exists = True
try:
# Check if the user is already in the database
user_id = get_user_id_by_email(user_email)
except Exception:
init_db()
try:
user_id = get_user_id_by_email(user_email)
except Exception as e:
logger.error(f"Error checking if user is in database: {e}")
user_id = None
db_exists = False
# If no entry is found in the database (indicated by user_id being None)...
if user_id is None:
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
if db_exists:
update_email_to_id_table(user_email, user_id)
return user_id
elif user_id is None:
return None
elif user_id == NULL_ID_STRING:
return None
# If the found user_id is real, cache it
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
return user_id

View File

@@ -3,8 +3,6 @@ from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -68,13 +66,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -64,15 +64,7 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -95,16 +87,6 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -112,8 +94,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# Use the configured scopes
base_scopes=oidc_scopes,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
),
auth_backend,
USER_AUTH_SECRET,

View File

@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) The anonymous user cookie
3) Reset token cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
@@ -52,55 +52,41 @@ async def _get_tenant_id_from_request(
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if token_data:
tenant_id_from_payload = token_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else None
)
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if tenant_id and not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(
anonymous_user_cookie
)
tenant_id = anonymous_user_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
if not tenant_id or not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
return tenant_id
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")

View File

@@ -14,6 +14,7 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
@@ -55,6 +56,25 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -91,6 +111,9 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

View File

@@ -8,6 +8,7 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
@@ -163,6 +164,8 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
@@ -217,3 +220,4 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -1,6 +1,5 @@
import json
from collections.abc import Generator
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
@@ -24,12 +23,8 @@ from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.app_configs import FAST_SEARCH_MAX_HITS
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import SavedSearchDocWithContent
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.utils import dedupe_documents
@@ -235,26 +230,6 @@ def get_answer_with_citation(
raise HTTPException(status_code=500, detail="An internal server error occurred")
@basic_router.post("/search")
def get_search_response(
request: OneShotQARequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> StreamingResponse:
def stream_generator() -> Generator[str, None, None]:
try:
for packet in get_answer_stream(request, user, db_session):
print("packet is")
print(packet.__dict__)
serialized = get_json_line(packet.model_dump())
yield serialized
except Exception as e:
logger.exception("Error in answer streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")
@basic_router.post("/stream-answer-with-citation")
def stream_answer_with_citation(
request: OneShotQARequest,
@@ -289,112 +264,3 @@ def get_standard_answer(
except Exception as e:
logger.error(f"Error in get_standard_answer: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail="An internal server error occurred")
class FastSearchRequest(BaseModel):
"""Request for fast search endpoint that returns raw search results without section merging."""
query: str
filters: BaseFilters | None = (
None # Direct filter options instead of retrieval_options
)
max_results: Optional[
int
] = None # If not provided, defaults to FAST_SEARCH_MAX_HITS
class FastSearchResult(BaseModel):
"""A search result without section expansion or merging."""
document_id: str
chunk_id: int
content: str
source_links: dict[int, str] | None = None
score: Optional[float] = None
metadata: Optional[dict] = None
class FastSearchResponse(BaseModel):
"""Response from the fast search endpoint."""
results: list[SearchDoc]
total_found: int
@basic_router.post("/fast-search")
def get_fast_search_response(
request: FastSearchRequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> FastSearchResponse:
"""Endpoint for fast search that returns up to 300 results without section merging.
This is optimized for quickly returning a large number of search results without the overhead
of section expansion, reranking, relevance evaluation, and merging.
"""
try:
# Set up the search request with optimized settings
max_results = request.max_results or FAST_SEARCH_MAX_HITS
# Create a search request with optimized settings
search_request = SearchRequest(
query=request.query,
human_selected_filters=request.filters,
# Skip section expansion
chunks_above=0,
chunks_below=0,
# Skip LLM evaluation
evaluation_type=LLMEvaluationType.SKIP,
# Limit the number of results
limit=max_results,
)
# Set up the LLM instances
llm, fast_llm = get_default_llms()
# Create the search pipeline with optimized settings
search_pipeline = SearchPipeline(
search_request=search_request,
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True, # Skip expensive query analysis
db_session=db_session,
bypass_acl=False,
)
# Only retrieve chunks without further processing
chunks = search_pipeline._get_chunks()
# Convert chunks to response format
results = [
SearchDoc(
document_id=chunk.document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier,
link=None, # Assuming source_links might be used for link
blurb=chunk.content,
source_type=chunk.source_type, # Default source type
boost=0, # Default boost value
hidden=False, # Default visibility
metadata=chunk.metadata,
score=chunk.score,
is_relevant=None,
relevance_explanation=None,
match_highlights=[],
updated_at=chunk.updated_at,
primary_owners=None,
secondary_owners=None,
is_internet=False,
)
for chunk in chunks
]
return FastSearchResponse(
results=results,
total_found=len(results),
)
except Exception as e:
logger.exception("Error in fast search")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -38,7 +38,6 @@ router = APIRouter(prefix="/auth/saml")
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
@@ -49,13 +48,9 @@ async def upsert_saml_user(email: str) -> User:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if not user.role.is_web_login():
raise exceptions.UserNotExists()
return user
return await user_manager.get_by_email(email)
except exceptions.UserNotExists:
logger.info("Creating user from SAML login")
logger.notice("Creating user from SAML login")
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -64,10 +59,11 @@ async def upsert_saml_user(email: str) -> User:
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user = await user_manager.create(
user: User = await user_manager.create(
UserCreate(
email=email,
password=hashed_pass,
is_verified=True,
role=role,
)
)

View File

@@ -87,15 +87,11 @@ async def get_or_provision_tenant(
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
tenant_id = await create_tenant(email, referral_source)
# Notify control plane if we have created / assigned a new tenant
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
return tenant_id
return tenant_id
except Exception as e:
# If we've encountered an error, log and raise an exception
@@ -120,6 +116,10 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
@@ -506,11 +506,8 @@ async def setup_tenant(tenant_id: str) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations in a way that isolates it from the current event loop
# Create a new event loop for this synchronous operation
loop = asyncio.get_event_loop()
# Use run_in_executor which properly isolates the thread execution
await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id))
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
@@ -564,3 +561,7 @@ async def assign_tenant_to_user(
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -70,7 +70,6 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
@@ -89,25 +88,9 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
)
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Commit the transaction
db_session.commit()

Binary file not shown.

View File

@@ -1,4 +1,3 @@
import logging
import os
import shutil
from collections.abc import AsyncGenerator
@@ -9,7 +8,6 @@ import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
@@ -22,8 +20,6 @@ from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -40,12 +36,6 @@ transformer_logging.set_verbosity_error()
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
"""
@@ -122,15 +112,6 @@ def get_model_app() -> FastAPI:
application.include_router(encoders_router)
application.include_router(custom_models_router)
request_id_prefix = "INF"
if INDEXING_ONLY:
request_id_prefix = "IDX"
add_onyx_request_id_middleware(application, request_id_prefix, logger)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
doc_access = DocumentAccess.build(
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,8 +26,6 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -40,12 +38,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess.build(
user_emails=[],
user_groups=[],
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
@@ -57,18 +55,19 @@ def _get_access_for_documents(
db_session=db_session,
document_ids=document_ids,
)
doc_access = {}
for document_id, user_emails, is_public in document_access_info:
doc_access[document_id] = DocumentAccess.build(
user_emails=[email for email in user_emails if email],
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=[],
user_groups=set(),
is_public=is_public,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -15,22 +15,6 @@ class ExternalAccess:
# Whether the document is public in the external system or Onyx
is_public: bool
def __str__(self) -> str:
"""Prevent extremely long logs"""
def truncate_set(s: set[str], max_len: int = 100) -> str:
s_str = str(s)
if len(s_str) > max_len:
return f"{s_str[:max_len]}... ({len(s)} items)"
return s_str
return (
f"ExternalAccess("
f"external_user_emails={truncate_set(self.external_user_emails)}, "
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
f"is_public={self.is_public})"
)
@dataclass(frozen=True)
class DocExternalAccess:
@@ -72,45 +56,33 @@ class DocExternalAccess:
)
@dataclass(frozen=True, init=False)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
@@ -121,32 +93,29 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
is_public=True,
)

View File

@@ -7,6 +7,7 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -23,7 +24,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -1,62 +0,0 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
def parallel_object_source_research_edge(
state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
search_objects = state.analysis_objects
search_sources = state.analysis_sources
object_source_combinations = [
(object, source) for object in search_objects for source in search_sources
]
return [
Send(
"research_object_source",
ObjectSourceInput(
object_source_combination=object_source_combination,
log_messages=[],
),
)
for object_source_combination in object_source_combinations
]
def parallel_object_research_consolidation_edge(
state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
cast(GraphConfig, config["metadata"]["config"])
object_research_information_results = state.object_research_information_results
return [
Send(
"consolidate_object_research",
ObjectInformationInput(
object_information=object_information,
log_messages=[],
),
)
for object_information in object_research_information_results
]

View File

@@ -1,103 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
### Add nodes ###
graph.add_node(
"search_objects",
search_objects,
)
graph.add_node(
"structure_research_by_source",
structure_research_by_object,
)
graph.add_node(
"research_object_source",
research_object_source,
)
graph.add_node(
"consolidate_object_research",
consolidate_object_research,
)
graph.add_node(
"consolidate_research",
consolidate_research,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="search_objects")
graph.add_conditional_edges(
source="search_objects",
path=parallel_object_source_research_edge,
path_map=["research_object_source"],
)
graph.add_edge(
start_key="research_object_source",
end_key="structure_research_by_source",
)
graph.add_conditional_edges(
source="structure_research_by_source",
path=parallel_object_research_consolidation_edge,
path_map=["consolidate_object_research"],
)
graph.add_edge(
start_key="consolidate_object_research",
end_key="consolidate_research",
)
graph.add_edge(
start_key="consolidate_research",
end_key=END,
)
return graph

View File

@@ -1,159 +0,0 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.configs.constants import DocumentSource
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def search_objects(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.search_request.persona.prompts[
0
].system_prompt
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
)
if agent_1_instructions is None:
raise ValueError("Agent 1 instructions not found")
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_1_task = extract_section(
agent_1_instructions, "Task:", "Independent Research Sources:"
)
if agent_1_task is None:
raise ValueError("Agent 1 task not found")
agent_1_independent_sources_str = extract_section(
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
)
if agent_1_independent_sources_str is None:
raise ValueError("Agent 1 Independent Research Sources not found")
document_sources = [
DocumentSource(x.strip().lower())
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
]
agent_1_output_objective = extract_section(
agent_1_instructions, "Output Objective:"
)
if agent_1_output_objective is None:
raise ValueError("Agent 1 output objective not found")
except Exception as e:
raise ValueError(
f"Agent 1 instructions not found or not formatted correctly: {e}"
)
# Extract objects
if agent_1_base_data is None:
# Retrieve chunks for objects
retrieved_docs = research(question, search_tool)[:10]
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
document_text=document_texts,
objects_of_interest=agent_1_output_objective,
)
else:
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
base_data=agent_1_base_data,
objects_of_interest=agent_1_output_objective,
)
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_extraction_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
cleaned_response = cleaned_response.split("OBJECTS:")[1]
object_list = [x.strip() for x in cleaned_response.split(";")]
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" Researching the individual objects for each source type... ",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,
log_messages=["Agent 1 Task done"],
)

View File

@@ -1,185 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def research_object_source(
state: ObjectSourceInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.search_request.persona.prompts[
0
].system_prompt
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
)
if agent_2_instructions is None:
raise ValueError("Agent 2 instructions not found")
agent_2_task = extract_section(
agent_2_instructions, "Task:", "Independent Research Sources:"
)
if agent_2_task is None:
raise ValueError("Agent 2 task not found")
agent_2_time_cutoff = extract_section(
agent_2_instructions, "Time Cutoff:", "Research Topics:"
)
agent_2_research_topics = extract_section(
agent_2_instructions, "Research Topics:", "Output Objective"
)
agent_2_output_objective = extract_section(
agent_2_instructions, "Output Objective:"
)
if agent_2_output_objective is None:
raise ValueError("Agent 2 output objective not found")
except Exception:
raise ValueError(
"Agent 1 instructions not found or not formatted correctly: {e}"
)
# Populate prompt
# Retrieve chunks for objects
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
if agent_2_time_cutoff.strip().endswith("d"):
try:
days = int(agent_2_time_cutoff.strip()[:-1])
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
days=days
)
except ValueError:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
agent_2_source_start_time = None
document_sources = [document_source] if document_source else None
if len(question.strip()) > 0:
research_area = f"{question} for {object}"
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
research_area = f"{agent_2_research_topics} for {object}"
else:
research_area = object
retrieved_docs = research(
question=research_area,
search_tool=search_tool,
document_sources=document_sources,
time_cutoff=agent_2_source_start_time,
)
# Generate document text
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# Built prompt
today = datetime.now().strftime("%A, %Y-%m-%d")
dc_object_source_research_prompt = (
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
today=today,
question=question,
task=agent_2_task,
document_text=document_texts,
format=agent_2_output_objective,
)
.replace("---object---", object)
.replace("---source---", document_source.value)
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_source_research_prompt,
reserved_str="",
),
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
cleaned_response = str(llm_response.content).replace("```json\n", "")
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
object_research_results = {
"object": object,
"source": document_source.value,
"research_result": cleaned_response,
}
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ObjectSourceResearchUpdate(
object_source_research_results=[object_research_results],
log_messages=["Agent Step 2 done for one object"],
)

View File

@@ -1,68 +0,0 @@
from collections import defaultdict
from datetime import datetime
from typing import cast
from typing import Dict
from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger
logger = setup_logger()
def structure_research_by_object(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" consolidating the information across source types for each object...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
for object_source_research in object_source_research_results:
object = object_source_research["object"]
source = object_source_research["source"]
research_result = object_source_research["research_result"]
object_research_information_results_list[object].append(
f"Source: {source}\n{research_result}"
)
for object, information in object_research_information_results_list.items():
object_research_information_results.append(
{"object": object, "information": "\n".join(information)}
)
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
return ObjectResearchInformationUpdate(
object_research_information_results=object_research_information_results,
log_messages=["A3 - Object Research Information structured"],
)

View File

@@ -1,107 +0,0 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def consolidate_object_research(
state: ObjectInformationInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
instructions = graph_config.inputs.search_request.persona.prompts[0].system_prompt
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
)
if agent_4_instructions is None:
raise ValueError("Agent 4 instructions not found")
agent_4_output_objective = extract_section(
agent_4_instructions, "Output Objective:"
)
if agent_4_output_objective is None:
raise ValueError("Agent 4 output objective not found")
object_information = state.object_information
object = object_information["object"]
information = object_information["information"]
# Create a prompt for the object consolidation
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
question=question,
object=object,
information=information,
format=agent_4_output_objective,
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_consolidation_prompt,
reserved_str="",
),
)
]
graph_config.tooling.primary_llm
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
cleaned_response = str(llm_response.content).replace("```json\n", "")
consolidated_information = cleaned_response.split("INFORMATION:")[1]
except Exception as e:
raise ValueError(f"Error in consolidate_object_research: {e}")
object_research_results = {
"object": object,
"research_result": consolidated_information,
}
logger.debug(
"DivCon Step A4 - Object Research Consolidation - completed for an object"
)
return ObjectResearchUpdate(
object_research_results=[object_research_results],
log_messages=["Agent Source Consilidation done"],
)

View File

@@ -1,164 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def consolidate_research(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" generating the answer\n\n\n",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# Populate prompt
instructions = graph_config.inputs.search_request.persona.prompts[0].system_prompt
try:
agent_5_instructions = extract_section(
instructions, "Agent Step 5:", "Agent End"
)
if agent_5_instructions is None:
raise ValueError("Agent 5 instructions not found")
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_5_task = extract_section(
agent_5_instructions, "Task:", "Independent Research Sources:"
)
if agent_5_task is None:
raise ValueError("Agent 5 task not found")
agent_5_output_objective = extract_section(
agent_5_instructions, "Output Objective:"
)
if agent_5_output_objective is None:
raise ValueError("Agent 5 output objective not found")
except ValueError as e:
raise ValueError(
f"Instructions for Agent Step 5 were not properly formatted: {e}"
)
research_result_list = []
if agent_5_task.strip() == "*concatenate*":
object_research_results = state.object_research_results
for object_research_result in object_research_results:
object = object_research_result["object"]
research_result = object_research_result["research_result"]
research_result_list.append(f"Object: {object}\n\n{research_result}")
research_results = "\n\n".join(research_result_list)
else:
raise NotImplementedError("Only '*concatenate*' is currently supported")
# Create a prompt for the object consolidation
if agent_5_base_data is None:
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
text=research_results,
format=agent_5_output_objective,
)
else:
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
base_data=agent_5_base_data,
text=research_results,
format=agent_5_output_objective,
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_formatting_prompt,
reserved_str="",
),
)
]
dispatch_timings: list[float] = []
primary_model = graph_config.tooling.primary_llm
def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
_ = run_with_timeout(
60,
stream_initial_answer,
)
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
logger.debug("DivCon Step A5 - Final Generation - completed")
return ResearchUpdate(
research_results=research_results,
log_messages=["Agent Source Consilidation done"],
)

View File

@@ -1,61 +0,0 @@
from datetime import datetime
from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
if len(parts) == 1:
return None
after_start = parts[1].strip()
if not end_marker:
return after_start
extract = after_start.split(end_marker)[0]
return extract.strip()

View File

@@ -1,72 +0,0 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
analysis_objects: list[str] = []
analysis_sources: list[DocumentSource] = []
class ObjectSourceInput(LoggerUpdate):
object_source_combination: tuple[str, DocumentSource]
class ObjectSourceResearchUpdate(LoggerUpdate):
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectInformationInput(LoggerUpdate):
object_information: Dict[str, str]
class ObjectResearchInformationUpdate(LoggerUpdate):
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
object_research_results: Annotated[list[Dict[str, str]], add] = []
class ResearchUpdate(LoggerUpdate):
research_results: str | None = None
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
SearchSourcesObjectsUpdate,
ObjectSourceResearchUpdate,
ObjectResearchInformationUpdate,
ObjectResearchUpdate,
ResearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -156,6 +156,7 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -183,6 +183,7 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -57,6 +57,7 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -13,7 +13,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -57,7 +59,9 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(section_to_llm_doc(section))
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -8,10 +8,6 @@ from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
@@ -86,7 +82,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput,
graph_input: BasicInput | MainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -100,7 +96,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput,
input: BasicInput | MainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -150,16 +146,6 @@ def run_basic_graph(
return run_graph(compiled_graph, config, input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:
graph = divide_and_conquer_graph_builder()
compiled_graph = graph.compile()
input = DCMainInput(log_messages=[])
config.inputs.search_request.query = config.inputs.search_request.query.strip()
return run_graph(compiled_graph, config, input)
if __name__ == "__main__":
for _ in range(1):
query_start_time = datetime.now()

View File

@@ -180,35 +180,3 @@ def binary_string_test_after_answer_separator(
relevant_text = text.split(f"{separator}")[-1]
return binary_string_test(relevant_text, positive_value)
def build_dc_search_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
date_str = build_date_time_string()
docs_str = format_docs(docs)
docs_str = trim_prompt_piece(
config,
docs_str,
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=SUB_QUESTION_RAG_PROMPT.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
return [system_message, human_message]

View File

@@ -321,10 +321,8 @@ def dispatch_separated(
sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
num = 1
accumulated_tokens = ""
streamed_tokens: list[BaseMessage_Content] = []
for token in tokens:
accumulated_tokens += cast(str, token.content)
content = cast(str, token.content)
if sep in content:
sub_question_parts = content.split(sep)

View File

@@ -16,14 +16,13 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
@@ -63,11 +62,6 @@ HTML_EMAIL_TEMPLATE = """\
}}
.header img {{
max-width: 140px;
width: 140px;
height: auto;
filter: brightness(1.1) contrast(1.2);
border-radius: 8px;
padding: 5px;
}}
.body-content {{
padding: 20px 30px;
@@ -84,16 +78,12 @@ HTML_EMAIL_TEMPLATE = """\
}}
.cta-button {{
display: inline-block;
padding: 14px 24px;
background-color: #0055FF;
padding: 12px 20px;
background-color: #000000;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 600;
font-size: 16px;
margin-top: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
text-align: center;
font-weight: 500;
}}
.footer {{
font-size: 13px;
@@ -176,7 +166,6 @@ def send_email(
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
@@ -185,30 +174,17 @@ def send_email(
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
# Add text part first (lowest priority)
text_part = MIMEText(text_body, "plain")
msg.attach(text_part)
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
msg.attach(part_text)
msg.attach(part_html)
if inline_png:
# For HTML with images, create a multipart/related container
related = MIMEMultipart("related")
# Add the HTML part to the related container
html_part = MIMEText(html_body, "html")
related.attach(html_part)
# Add image with proper Content-ID to the related container
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", f"<{inline_png[0]}>")
img.add_header("Content-ID", inline_png[0]) # CID reference
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
related.attach(img)
# Add the related part to the message (higher priority than text)
msg.attach(related)
else:
# No images, just add HTML directly (higher priority than text)
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
msg.attach(img)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -356,23 +332,17 @@ def send_forgot_password_email(
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Reset Your {application_name} Password"
heading = "Reset Your Password"
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
cta_text = "Reset Password"
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
subject = f"{application_name} Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email(
application_name,
heading,
"Reset Your Password",
message,
cta_text,
cta_link,
)
text_content = (
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
)
text_content = f"Click the following link to reset your password: {link}"
send_email(
user_email,
subject,
@@ -386,7 +356,6 @@ def send_forgot_password_email(
def send_user_verification_email(
user_email: str,
token: str,
new_organization: bool = False,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
@@ -403,8 +372,6 @@ def send_user_verification_email(
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
if new_organization:
link = add_url_params(link, {"first_user": "true"})
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)

View File

@@ -1,211 +0,0 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh

View File

@@ -5,16 +5,12 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -56,7 +52,6 @@ from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
@@ -361,6 +356,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reason="Password must contain at least one special character from the following set: "
f"{PASSWORD_SPECIAL_CHARS}."
)
return
async def oauth_callback(
@@ -514,25 +510,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def on_after_login(
self,
user: User,
request: Optional[Request] = None,
response: Optional[Response] = None,
) -> None:
try:
if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
response.delete_cookie(
ANONYMOUS_USER_COOKIE_NAME,
# Ensure cookie deletion doesn't override other cookies by setting the same path/domain
path="/",
domain=None,
secure=WEB_DOMAIN.startswith("https"),
)
logger.debug(f"Deleted anonymous user cookie for user {user.email}")
except Exception:
logger.exception("Error deleting anonymous user cookie")
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
@@ -604,10 +581,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
user_count = await get_user_count()
send_user_verification_email(
user.email, token, new_organization=user_count == 1
)
send_user_verification_email(user.email, token)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
@@ -713,20 +688,16 @@ cookie_transport = CookieTransport(
)
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -785,75 +756,6 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -904,88 +806,6 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -1219,20 +1039,12 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
@@ -1322,7 +1134,6 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
@@ -1332,14 +1143,9 @@ def get_oauth_router(
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
# FastAPI can have multiple Set-Cookie headers as a list
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
for cookie_value in header_value:
redirect_response.headers.append(header_name, cookie_value)
else:
redirect_response.headers[header_name] = header_value
redirect_response.headers[header_name] = header_value
if hasattr(response, "body"):
redirect_response.body = response.body

View File

@@ -1,6 +1,5 @@
import logging
import multiprocessing
import os
import time
from typing import Any
from typing import cast
@@ -35,6 +34,7 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -225,7 +225,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
r = get_shared_redis_client()
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -306,12 +306,12 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
r = get_shared_redis_client()
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")

View File

@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -9,10 +10,12 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -138,6 +141,8 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -147,7 +152,16 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = OnyxRuntime.get_beat_multiplier()
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

View File

@@ -1,7 +0,0 @@
from celery import Celery
import onyx.background.celery.apps.app_base as app_base
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]

View File

@@ -111,7 +111,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]

View File

@@ -1,5 +1,4 @@
import logging
import os
from typing import Any
from typing import cast
@@ -39,11 +38,10 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -96,7 +94,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
logger.info("Running as the primary celery worker.")
# Less startup checks in multi-tenant case
if MULTI_TENANT:
@@ -104,7 +102,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
r = get_shared_redis_client()
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -175,9 +173,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@@ -240,7 +235,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
r = get_shared_redis_client()
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
@@ -289,6 +284,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.user_file_folder_sync",
]
)

View File

@@ -1,16 +0,0 @@
import onyx.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

View File

@@ -14,7 +14,7 @@ logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files

View File

@@ -21,7 +21,6 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted on cloud
beat_task_templates: list[dict] = []
@@ -64,15 +63,6 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-user-file-folder-sync",
"task": OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,

View File

@@ -30,9 +30,6 @@ from onyx.db.connector_credential_pair import (
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
delete_all_documents_by_connector_credential_pair__no_commit,
)
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
@@ -389,8 +386,6 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -445,35 +440,16 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id_to_delete,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -506,15 +482,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
@@ -564,7 +540,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_upsert_tasks: set[str],
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -651,7 +627,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_upsert_tasks:
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1

View File

@@ -17,7 +17,6 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -64,14 +63,11 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -108,10 +104,9 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
@@ -289,7 +284,7 @@ def try_creating_permissions_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
priority=OnyxCeleryPriority.HIGH,
)
# fill in the celery task id
@@ -880,18 +875,6 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"total_docs_synced": initial if initial is not None else 0,
"remaining_docs_to_sync": remaining,
},
tenant_id=tenant_id,
)
if remaining > 0:
return
@@ -903,13 +886,6 @@ def monitor_ccpair_permissions_taskset(
f"num_synced={initial}"
)
# Add telemetry for permission syncing complete
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
data={"cc_pair_id": cc_pair_id},
tenant_id=tenant_id,
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,

View File

@@ -271,7 +271,7 @@ def try_creating_external_group_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
priority=OnyxCeleryPriority.HIGH,
)
payload.celery_task_id = result.id

View File

@@ -72,7 +72,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -365,7 +364,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
task_logger.warning("check_for_indexing - Starting")
tasks_created = 0
locked = False
@@ -403,11 +401,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
# 1/3: KICKOFF
@@ -434,9 +428,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
cc_pairs = fetch_connector_credential_pairs(
db_session, include_user_files=True
)
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -455,18 +447,12 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
task_logger.warning("SKIPPING DUE TO NON-LIVE SEARCH SETTINGS")
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
if redis_connector_index.fenced:
task_logger.info(
f"check_for_indexing - Skipping fenced connector: "
f"cc_pair={cc_pair_id} search_settings={search_settings_instance.id}"
)
continue
cc_pair = get_connector_credential_pair_from_id(
@@ -474,9 +460,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"check_for_indexing - CC pair not found: cc_pair={cc_pair_id}"
)
continue
last_attempt = get_last_attempt_for_cc_pair(
@@ -490,20 +473,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
task_logger.info(
f"check_for_indexing - Not indexing cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
continue
else:
task_logger.info(
f"check_for_indexing - Will index cc_pair_id: {cc_pair_id} "
f"search_settings={search_settings_instance.id}, "
f"last_attempt={last_attempt.id if last_attempt else None}, "
f"secondary_index_building={len(search_settings_list) > 1}"
)
reindex = False
if search_settings_instance.status.is_current():
@@ -542,12 +512,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
else:
task_logger.info(
f"Failed to create indexing task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id}"
)
lock_beat.reacquire()
@@ -1180,9 +1144,6 @@ def connector_indexing_proxy_task(
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,

View File

@@ -371,7 +371,6 @@ def should_index(
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source")
return False
# User can still manually create single indexing attempts via the UI for the
@@ -381,9 +380,6 @@ def should_index(
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
print(
f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building"
)
return False
# When switching over models, always index at least once
@@ -392,31 +388,19 @@ def should_index(
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index.id}"
)
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index.id}"
)
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
)
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
print(
f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
)
return False
return True
@@ -428,9 +412,6 @@ def should_index(
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
print(
f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API"
)
return False
if search_settings_instance.status.is_current():
@@ -443,16 +424,11 @@ def should_index(
return True
if connector.refresh_freq is None:
print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None")
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
print(
f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index.id} "
f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
)
return False
return True
@@ -532,13 +508,6 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# Determine which queue to use based on whether this is a user file
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_INDEXING
)
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
@@ -549,7 +518,7 @@ def try_creating_indexing_task(
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=queue,
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)

View File

@@ -6,7 +6,6 @@ from tenacity import wait_random_exponential
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
class RetryDocumentIndex:
@@ -53,13 +52,11 @@ class RetryDocumentIndex:
*,
tenant_id: str,
chunk_count: int | None,
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
fields: VespaDocumentFields,
) -> int:
return self.index.update_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
fields=fields,
user_fields=user_fields,
)

View File

@@ -164,7 +164,6 @@ def document_by_cc_pair_cleanup_task(
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# there are still other cc_pair references to the doc, so just resync to Vespa

View File

@@ -1,266 +0,0 @@
import time
from typing import List
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_with_user_files,
)
from onyx.db.document import get_document
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.search_settings import get_active_search_settings
from onyx.db.user_documents import fetch_user_files_for_documents
from onyx.db.user_documents import fetch_user_folders_for_documents
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_FOLDER_SYNC,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_for_user_file_folder_sync(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check for documents that need user file folder metadata updates.
This task fetches all connector credential pairs with user files, gets the documents
associated with them, and updates the user file and folder metadata in Vespa.
"""
time_start = time.monotonic()
r = get_redis_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK,
timeout=CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
with get_session_with_current_tenant() as db_session:
# Get all connector credential pairs that have user files
cc_pairs = get_connector_credential_pairs_with_user_files(db_session)
if not cc_pairs:
task_logger.info("No connector credential pairs with user files found")
return True
# Get all documents associated with these cc_pairs
document_ids = get_documents_for_cc_pairs(cc_pairs, db_session)
if not document_ids:
task_logger.info(
"No documents found for connector credential pairs with user files"
)
return True
# Fetch current user file and folder IDs for these documents
doc_id_to_user_file_id = fetch_user_files_for_documents(
document_ids=document_ids, db_session=db_session
)
doc_id_to_user_folder_id = fetch_user_folders_for_documents(
document_ids=document_ids, db_session=db_session
)
# Update Vespa metadata for each document
for doc_id in document_ids:
user_file_id = doc_id_to_user_file_id.get(doc_id)
user_folder_id = doc_id_to_user_folder_id.get(doc_id)
if user_file_id is not None or user_folder_id is not None:
# Schedule a task to update the document metadata
update_user_file_folder_metadata.apply_async(
args=(doc_id,), # Use tuple instead of list for args
kwargs={
"tenant_id": tenant_id,
"user_file_id": user_file_id,
"user_folder_id": user_folder_id,
},
queue="vespa_metadata_sync",
)
task_logger.info(
f"Scheduled metadata updates for {len(document_ids)} documents. "
f"Elapsed time: {time.monotonic() - time_start:.2f}s"
)
return True
except Exception as e:
task_logger.exception(f"Error in check_for_user_file_folder_sync: {e}")
return False
finally:
lock_beat.release()
def get_documents_for_cc_pairs(
cc_pairs: List[ConnectorCredentialPair], db_session: Session
) -> List[str]:
"""Get all document IDs associated with the given connector credential pairs."""
if not cc_pairs:
return []
cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs]
# Query to get document IDs from DocumentByConnectorCredentialPair
# Note: DocumentByConnectorCredentialPair uses connector_id and credential_id, not cc_pair_id
doc_cc_pairs = (
db_session.query(Document.id)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,
)
.filter(
db_session.query(ConnectorCredentialPair)
.filter(
ConnectorCredentialPair.id.in_(cc_pair_ids),
ConnectorCredentialPair.connector_id
== DocumentByConnectorCredentialPair.connector_id,
ConnectorCredentialPair.credential_id
== DocumentByConnectorCredentialPair.credential_id,
)
.exists()
)
.all()
)
return [doc_id for (doc_id,) in doc_cc_pairs]
@shared_task(
name=OnyxCeleryTask.UPDATE_USER_FILE_FOLDER_METADATA,
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def update_user_file_folder_metadata(
self: Task,
document_id: str,
*,
tenant_id: str,
user_file_id: int | None,
user_folder_id: int | None,
) -> bool:
"""Updates the user file and folder metadata for a document in Vespa."""
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
secondary_search_settings=active_search_settings.secondary,
httpx_client=HttpxPool.get("vespa"),
)
retry_index = RetryDocumentIndex(doc_index)
doc = get_document(document_id, db_session)
if not doc:
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
return False
# Create user fields object with file and folder IDs
user_fields = VespaDocumentUserFields(
user_file_id=str(user_file_id) if user_file_id is not None else None,
user_folder_id=str(user_folder_id)
if user_folder_id is not None
else None,
)
# Update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=None, # We're only updating user fields
user_fields=user_fields,
)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=user_file_folder_sync "
f"user_file_id={user_file_id} "
f"user_folder_id={user_folder_id} "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
return True
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
task_logger.exception(
f"update_user_file_folder_metadata exceptioned: doc={document_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"update_user_file_folder_metadata completed: status={completion_status.value} doc={document_id}"
)
return False

View File

@@ -80,8 +80,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
# Useful for debugging timing issues with reacquisitions.
# TODO: remove once more generalized logging is in place
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
task_logger.info("check_for_vespa_sync_task started")
time_start = time.monotonic()
@@ -573,7 +572,6 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
user_fields=None,
)
# update db last. Worst case = we crash right before this and

View File

@@ -1,20 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker.
This is an app stub purely for sending tasks as a client.
"""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.client import celery_app
return celery_app
app = get_app()

View File

@@ -59,8 +59,6 @@ from onyx.natural_language_processing.search_nlp_models import (
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -273,6 +271,7 @@ def _run_indexing(
"Search settings must be set for indexing. This should not be possible."
)
# search_settings = index_attempt_start.search_settings
db_connector = index_attempt_start.connector_credential_pair.connector
db_credential = index_attempt_start.connector_credential_pair.credential
ctx = RunIndexingContext(
@@ -436,7 +435,7 @@ def _run_indexing(
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
@@ -571,19 +570,6 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"current_docs_indexed": document_count,
"current_chunks_indexed": chunk_count,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# `make sure the checkpoints aren't getting too large`at some regular interval
@@ -599,19 +585,6 @@ def _run_indexing(
checkpoint=checkpoint,
)
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
@@ -622,9 +595,6 @@ def _run_indexing(
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
@@ -671,9 +641,6 @@ def _run_indexing(
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
@@ -736,7 +703,6 @@ def _run_indexing(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(

View File

@@ -10,7 +10,6 @@ from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@@ -143,18 +142,11 @@ class Answer:
yield from self._processed_stream
return
if self.graph_config.behavior.use_agentic_search:
run_langgraph = run_main_graph
elif (
self.graph_config.inputs.search_request.persona
and self.graph_config.inputs.search_request.persona.description.startswith(
"DivCon Beta Agent"
)
):
run_langgraph = run_dc_graph
else:
run_langgraph = run_basic_graph
run_langgraph = (
run_main_graph
if self.graph_config.behavior.use_agentic_search
else run_basic_graph
)
stream = run_langgraph(
self.graph_config,
)

View File

@@ -127,10 +127,6 @@ class StreamStopInfo(SubQuestionIdentifier):
return data
class UserKnowledgeFilePacket(BaseModel):
user_files: list[FileDescriptor]
class LLMRelevanceFilterResponse(BaseModel):
llm_selected_doc_indices: list[int]
@@ -198,6 +194,17 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
@@ -263,6 +270,7 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
@@ -36,14 +37,12 @@ from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
@@ -53,7 +52,6 @@ from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
@@ -67,7 +65,6 @@ from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.chat import attach_files_to_chat_message
from onyx.db.chat import create_db_search_doc
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import create_search_doc_from_user_file
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_db_search_doc_by_id
@@ -76,7 +73,6 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
@@ -84,16 +80,12 @@ from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.persona import get_persona_by_id
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_all_user_file_files
from onyx.file_store.utils import load_all_user_files
from onyx.file_store.utils import save_files
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
@@ -106,7 +98,6 @@ from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.utils import get_json_line
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_constructor import construct_tools
@@ -139,6 +130,7 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -184,14 +176,11 @@ def _handle_search_tool_response_summary(
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
dedupe_docs: bool = False,
user_files: list[UserFile] | None = None,
loaded_user_files: list[InMemoryChatFile] | None = None,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)
is_extended = isinstance(packet, ExtendedToolResponse)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
@@ -205,31 +194,9 @@ def _handle_search_tool_response_summary(
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
doc_ids = {doc.id for doc in reference_db_search_docs}
if user_files is not None:
for user_file in user_files:
if user_file.id not in doc_ids:
associated_chat_file = None
if loaded_user_files is not None:
associated_chat_file = next(
(
file
for file in loaded_user_files
if file.file_id == str(user_file.file_id)
),
None,
)
# Use create_search_doc_from_user_file to properly add the document to the database
if associated_chat_file is not None:
db_doc = create_search_doc_from_user_file(
user_file, associated_chat_file, db_session
)
reference_db_search_docs.append(db_doc)
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
@@ -287,10 +254,7 @@ def _handle_internet_search_tool_response_summary(
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest,
tools: list[Tool],
user_file_ids: list[int],
user_folder_ids: list[int],
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
@@ -298,11 +262,8 @@ def _get_force_search_settings(
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
if new_msg_req.force_user_file_search:
return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
else:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
@@ -312,25 +273,12 @@ def _get_force_search_settings(
else None
)
# Create override_kwargs for the search tool if user_file_ids are provided
override_kwargs = None
if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=False,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
)
if new_msg_req.file_descriptors:
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
should_force_search = any(
[
new_msg_req.force_user_file_search,
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
@@ -343,22 +291,15 @@ def _get_force_search_settings(
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(
force_use=True,
tool_name=tool_name,
args=args,
override_kwargs=override_kwargs,
)
return ForceUseTool(
force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -372,7 +313,6 @@ ChatPacket = (
| AgenticMessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
| UserKnowledgeFilePacket
)
ChatPacketStream = Iterator[ChatPacket]
@@ -418,10 +358,6 @@ def stream_chat_message_objects(
llm: LLM
try:
# Move these variables inside the try block
file_id_to_user_file = {}
ordered_user_files = None
user_id = user.id if user is not None else None
chat_session = get_chat_session_by_id(
@@ -601,70 +537,6 @@ def stream_chat_message_objects(
)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []
user_folder_ids = new_msg_req.user_folder_ids or []
if persona.user_files:
for file in persona.user_files:
user_file_ids.append(file.id)
if persona.user_folders:
for folder in persona.user_folders:
user_folder_ids.append(folder.id)
# Initialize flag for user file search
use_search_for_user_files = False
user_files: list[InMemoryChatFile] | None = None
search_for_ordering_only = False
user_file_files: list[UserFile] | None = None
if user_file_ids or user_folder_ids:
# Load user files
user_files = load_all_user_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
user_file_files = load_all_user_file_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Store mapping of file_id to file for later reordering
if user_files:
file_id_to_user_file = {file.file_id: file for file in user_files}
# Calculate token count for the files
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
total_tokens = calculate_user_files_token_count(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Calculate available tokens for documents based on prompt, user input, etc.
available_tokens = compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
actual_user_input=message_text, # Use the actual user message
)
logger.debug(
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
)
# ALWAYS use search for user files, but track if we need it for context or just ordering
use_search_for_user_files = True
# If files are small enough for context, we'll just use search for ordering
search_for_ordering_only = total_tokens <= available_tokens
if search_for_ordering_only:
# Add original user files to context since they fit
if user_files:
latest_query_files.extend(user_files)
if user_message:
attach_files_to_chat_message(
@@ -693,13 +565,8 @@ def stream_chat_message_objects(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
# Add a maximum context size in the case of user-selected docs to prevent
# slight inaccuracies in context window size pruning from causing
# the entire query to fail
document_pruning_config = DocumentPruningConfig(
is_manually_selected_docs=True,
max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
is_manually_selected_docs=True
)
# In case the search doc is deleted, just don't include it
@@ -812,10 +679,8 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
db_session=db_session,
user=user,
user_knowledge_present=bool(user_files or user_folder_ids),
llm=llm,
fast_llm=fast_llm,
use_file_search=new_msg_req.force_user_file_search,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
@@ -845,138 +710,17 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)
force_use_tool = _get_force_search_settings(
new_msg_req, tools, user_file_ids, user_folder_ids
)
# Set force_use if user files exceed token limit
if use_search_for_user_files:
try:
# Check if search tool is available in the tools list
search_tool_available = any(
isinstance(tool, SearchTool) for tool in tools
)
# If no search tool is available, add one
if not search_tool_available:
logger.info("No search tool available, creating one for user files")
# Create a basic search tool config
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
)
# Create and add the search tool
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
bypass_acl=bypass_acl,
)
# Add the search tool to the tools list
tools.append(search_tool)
logger.info(
"Added search tool for user files that exceed token limit"
)
# Now set force_use_tool.force_use to True
force_use_tool.force_use = True
force_use_tool.tool_name = SearchTool._NAME
# Set query argument if not already set
if not force_use_tool.args:
force_use_tool.args = {"query": final_msg.message}
# Pass the user file IDs to the search tool
if user_file_ids or user_folder_ids:
# Create a BaseFilters object with user_file_ids
if not retrieval_options:
retrieval_options = RetrievalDetails()
if not retrieval_options.filters:
retrieval_options.filters = BaseFilters()
# Set user file and folder IDs in the filters
retrieval_options.filters.user_file_ids = user_file_ids
retrieval_options.filters.user_folder_ids = user_folder_ids
# Create override kwargs for the search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
)
# Set the override kwargs in the force_use_tool
force_use_tool.override_kwargs = override_kwargs
if search_for_ordering_only:
logger.info(
"Fast path: Configured search tool with optimized settings for ordering-only"
)
logger.info(
"Fast path: Skipping reranking and query analysis for ordering-only mode"
)
logger.info(
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
)
else:
logger.info(
"Configured search tool to use ",
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders",
)
except Exception as e:
logger.exception(
f"Error configuring search tool for user files: {str(e)}"
)
use_search_for_user_files = False
# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
if not use_search_for_user_files and user_files:
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
)
for file in user_files
]
)
if search_for_ordering_only:
logger.info(
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
)
search_request = SearchRequest(
query=final_msg.message,
evaluation_type=(
LLMEvaluationType.SKIP
if search_for_ordering_only
else (
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
)
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
@@ -995,6 +739,7 @@ def stream_chat_message_objects(
),
)
force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
@@ -1063,22 +808,8 @@ def stream_chat_message_objects(
info = info_by_subq[
SubQuestionKey(level=level, question_num=level_question_num)
]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
continue
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
@@ -1088,91 +819,16 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=(
False
if search_for_ordering_only
else (
retrieval_options.dedupe_docs
if retrieval_options
else False
)
retrieval_options.dedupe_docs
if retrieval_options
else False
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=user_files
if search_for_ordering_only
else [],
)
# If we're using search just for ordering user files
if (
search_for_ordering_only
and user_files
and info.qa_docs_response
):
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
import time
ordering_start = time.time()
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(
f"ORDERING: Found {len(doc_order)} files from search results"
)
# Add any files that weren't in search results at the end
missing_files = [
f_id
for f_id in file_id_to_user_file.keys()
if f_id not in doc_order
]
missing_files.extend(doc_order)
doc_order = missing_files
logger.info(
f"ORDERING: Added {len(missing_files)} missing files to the end"
)
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id]
for f_id in doc_order
if f_id in file_id_to_user_file
]
time.time() - ordering_start
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
continue
if info.reference_db_search_docs is None:
logger.warning(
"No reference docs found for relevance filtering"
@@ -1262,6 +918,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1282,7 +940,7 @@ def stream_chat_message_objects(
]
info.tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -1364,16 +1022,10 @@ def stream_chat_message_objects(
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id.get(info.tool_result.tool_name, 0)
if info.tool_result
else None,
tool_name=info.tool_result.tool_name if info.tool_result else None,
tool_arguments=info.tool_result.tool_args
if info.tool_result
else None,
tool_result=info.tool_result.tool_result
if info.tool_result
else None,
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
tool_name=info.tool_result.tool_name,
tool_arguments=info.tool_result.tool_args,
tool_result=info.tool_result.tool_result,
)
if info.tool_result
else None
@@ -1417,8 +1069,6 @@ def stream_chat_message_objects(
prev_message = next_answer_message
logger.debug("Committing messages")
# Explicitly update the timestamp on the chat session
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)

View File

@@ -19,7 +19,6 @@ def translate_onyx_msg_to_langchain(
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
)

View File

@@ -153,8 +153,6 @@ def _apply_pruning(
# remove docs that are explicitly marked as not for QA
sections = _remove_sections_to_ignore(sections=sections)
section_idx_token_count: dict[int, int] = {}
final_section_ind = None
total_tokens = 0
for ind, section in enumerate(sections):
@@ -204,20 +202,10 @@ def _apply_pruning(
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
total_tokens += section_token_count
section_idx_token_count[ind] = section_token_count
if total_tokens > token_limit:
final_section_ind = ind
break
try:
logger.debug(f"Number of documents after pruning: {ind + 1}")
logger.debug("Number of tokens per document (pruned):")
for x, y in section_idx_token_count.items():
logger.debug(f"{x + 1}: {y}")
except Exception as e:
logger.error(f"Error logging prune statistics: {e}")
if final_section_ind is not None:
if is_manually_selected_docs or use_sections:
if final_section_ind != len(sections) - 1:
@@ -312,14 +300,7 @@ def prune_sections(
)
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -327,48 +308,33 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, i
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
added_chars = 0
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
sep = (
ADJACENT_CHUNK_SEP
if chunk.chunk_id == prev_chunk_id + 1
else DISTANT_CHUNK_SEP
)
merged_content.append(sep)
added_chars += len(sep)
if chunk.chunk_id == prev_chunk_id + 1:
merged_content.append("\n")
else:
merged_content.append("\n\n...\n\n")
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
return (
InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
),
added_chars,
return InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
)
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
# chunk de-duping and doc ordering
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
combined_section_lengths[section.center_chunk.document_id] += len(
section.combined_content
)
chunks_map = docs_map[section.center_chunk.document_id]
for chunk in [section.center_chunk] + section.chunks:
chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
if (
existing_chunk is None
@@ -379,22 +345,8 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
chunks_map[chunk.chunk_id] = chunk
new_sections = []
for doc_id, section_chunks in docs_map.items():
section_chunks_list = list(section_chunks.values())
merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
previous_length = combined_section_lengths[doc_id] + added_chars
# After merging, ensure the content respects the pruning done earlier. Each
# combined section is restricted to the sum of the lengths of the sections
# from the pruning step. Technically the correct approach would be to prune based
# on tokens AGAIN, but this is a good approximation and worth not adding the
# tokenization overhead. This could also be fixed if we added a way of removing
# chunks from sections in the pruning step; at the moment this issue largely
# exists because we only trim the final section's combined_content.
merged_section.combined_content = merged_section.combined_content[
:previous_length
]
new_sections.append(merged_section)
for section_chunks in docs_map.values():
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
# Sort by highest score, then by original document order
# It is now 1 large section per doc, the center chunk being the one with the highest score
@@ -406,26 +358,6 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
reverse=True,
)
try:
num_original_sections = len(sections)
num_original_document_ids = len(
set([section.center_chunk.document_id for section in sections])
)
num_merged_sections = len(new_sections)
num_merged_document_ids = len(
set([section.center_chunk.document_id for section in new_sections])
)
logger.debug(
f"Merged {num_original_sections} sections from {num_original_document_ids} documents "
f"into {num_merged_sections} new sections in {num_merged_document_ids} documents"
)
logger.debug("Number of chunks per document (new ranking):")
for x, y in enumerate(new_sections):
logger.debug(f"{x + 1}: {len(y.chunks)}")
except Exception as e:
logger.error(f"Error logging merge statistics: {e}")
return new_sections

View File

@@ -3,6 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk
@@ -11,7 +12,7 @@ class DocumentIdOrderMapping(BaseModel):
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0

View File

@@ -180,10 +180,6 @@ def get_tool_call_for_non_tool_calling_llm_impl(
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
# If we have override_kwargs, add them to the tool_args
if force_use_tool.override_kwargs is not None:
tool_args["override_kwargs"] = force_use_tool.override_kwargs
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(

View File

@@ -1,8 +1,6 @@
import json
import os
import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
@@ -159,7 +157,10 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
@@ -170,7 +171,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
@@ -385,27 +386,10 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
def get_current_tz_offset() -> int:
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
# remove tzinfo to compare non-timezone-aware objects.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
return round(time_diff.total_seconds() / 3600)
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector as some point.
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
@@ -437,7 +421,7 @@ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
@@ -495,11 +479,6 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Enable contextual retrieval
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
@@ -541,17 +520,6 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
# Average summary embeddings for contextual rag (not yet implemented)
AVERAGE_SUMMARY_EMBEDDINGS = (
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
)
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Miscellaneous
#####
@@ -708,8 +676,3 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)
FAST_SEARCH_MAX_HITS = 300

View File

@@ -3,7 +3,7 @@ import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
USER_FOLDERS_YAML = "./onyx/seeding/user_folders.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
@@ -16,9 +16,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# Maximum percentage of the context window to fill with selected sections
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(

View File

@@ -102,8 +102,6 @@ CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_USER_FILE_FOLDER_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -271,7 +269,6 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
PLAINTEXT_CACHE = "plaintext_cache"
OTHER = "other"
@@ -312,7 +309,6 @@ class OnyxCeleryQueues:
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Monitoring queue
MONITORING = "monitoring"
@@ -331,7 +327,6 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
CHECK_USER_FILE_FOLDER_SYNC_BEAT_LOCK = "da_lock:check_user_file_folder_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
@@ -387,7 +382,6 @@ ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
class OnyxCeleryTask:
@@ -402,7 +396,6 @@ class OnyxCeleryTask:
# Tenant pre-provisioning
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -411,7 +404,6 @@ class OnyxCeleryTask:
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"

View File

@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
raise ConnectorMissingCredentialError("Google Cloud Storage")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],

View File

@@ -65,7 +65,19 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
ONE_HOUR = 3600
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
class ConfluenceConnector(
@@ -195,6 +207,7 @@ class ConfluenceConnector(
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -359,13 +372,11 @@ class ConfluenceConnector(
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
logger.info(f"Processing attachment: {attachment['title']}")
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
@@ -418,17 +429,7 @@ class ConfluenceConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
try:
return self._fetch_document_batches(start, end)
except Exception as e:
if "field 'updated' is invalid" in str(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
)
return self._fetch_document_batches(start - ONE_HOUR, end)
raise
return self._fetch_document_batches(start, end)
def retrieve_all_slim_documents(
self,

View File

@@ -498,12 +498,10 @@ class OnyxConfluence:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.debug(
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
"but we have logic to work around it - don't worry this isn't"
f" causing an issue. Start: {new_start}, Previous Start: "
f"{previous_start}, Len Results: {len(results)}."
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start

View File

@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
@@ -343,14 +342,9 @@ def build_confluence_document_id(
Returns:
str: The document id
"""
# NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't
# end with "/" because it believes that makes it a file.
final_url = base_url.rstrip("/") + "/"
if is_cloud and not final_url.endswith("/wiki/"):
final_url = urljoin(final_url, "wiki") + "/"
final_url = urljoin(final_url, content_url.lstrip("/"))
return final_url
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
def datetime_from_string(datetime_string: str) -> datetime:
@@ -460,19 +454,6 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
logger.warning("HTTPError with `None` as response or as headers")
raise e
# Confluence Server returns 403 when rate limited
if e.response.status_code == 403:
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
FORBIDDEN_RETRY_DELAY = 10
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
logger.warning(
"403 error. This sometimes happens when we hit "
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
)
return FORBIDDEN_RETRY_DELAY
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()

View File

@@ -28,9 +28,8 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -70,9 +69,7 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None

View File

@@ -22,9 +22,8 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -52,7 +51,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -123,7 +122,7 @@ def _process_file(
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
if not is_valid_file_ext(extension):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
@@ -220,34 +219,24 @@ def _process_file(
# 2) Otherwise: text-based approach. Possibly with embedded images.
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
extraction_result = extract_text_and_images(
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
)
# Merge file-specific metadata (from file content) with provided metadata
if extraction_result.metadata:
logger.debug(
f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
)
metadata.update(extraction_result.metadata)
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
link_in_meta = metadata.get("link")
if extraction_result.text_content.strip():
logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
sections.append(
TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
)
if text_content.strip():
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(
extraction_result.embedded_images, start=1
):
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image reference
try:

View File

@@ -45,8 +45,6 @@ _FIREFLIES_API_QUERY = """
}
"""
ONE_MINUTE = 60
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[TextSection] = []
@@ -108,8 +106,6 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
)
# If not all transcripts are being indexed, try using a more-recently-generated
# API key.
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
@@ -195,9 +191,6 @@ class FirefliesConnector(PollConnector, LoadConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
# add some leeway to account for any timezone funkiness and/or bad handling
# of start time on the Fireflies side
start = max(0, start - ONE_MINUTE)
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)

View File

@@ -1,10 +1,8 @@
import copy
import time
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
@@ -15,30 +13,26 @@ from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
ITEMS_PER_PAGE = 100
_MAX_NUM_RATE_LIMIT_RETRIES = 5
@@ -54,7 +48,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
def _get_batch_rate_limited(
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
) -> list[PullRequest | Issue]:
) -> list[Any]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -75,6 +69,21 @@ def _get_batch_rate_limited(
)
def _batch_github_objects(
git_objs: PaginatedList, github_client: Github, batch_size: int
) -> Iterator[list[Any]]:
page_num = 0
while True:
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
page_num += 1
if not batch:
break
for mini_batch in batch_generator(batch, batch_size=batch_size):
yield mini_batch
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
@@ -86,9 +95,7 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None,
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
metadata={
"merged": str(pull_request.merged),
"state": pull_request.state,
@@ -115,58 +122,31 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
ISSUES = "issues"
class GithubConnectorCheckpoint(ConnectorCheckpoint):
stage: GithubConnectorStage
curr_page: int
cached_repo_ids: list[int] | None = None
cached_repo: SerializedRepository | None = None
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
self.github_client = (
Github(
credentials["github_access_token"],
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
)
if GITHUB_CONNECTOR_BASE_URL
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
else Github(credentials["github_access_token"])
)
return None
@@ -237,212 +217,85 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self,
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
checkpoint = copy.deepcopy(checkpoint)
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# All repositories
repos = self._get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
checkpoint.cached_repo = SerializedRepository(
id=checkpoint.cached_repo_ids[0],
headers=repos[0].raw_headers,
raw_data=repos[0].raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try direct access to a known attribute name first
if hasattr(self.github_client, "_requester"):
requester = self.github_client._requester
elif hasattr(self.github_client, "_Github__requester"):
requester = self.github_client._Github__requester
else:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
repo = checkpoint.cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logger.warning(
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
)
repo_id = checkpoint.cached_repo.id
repo = self.github_client.get_repo(repo_id)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch: list[Document] = []
pr_batch = _get_batch_rate_limited(
pull_requests, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_prs = False
for pr in pr_batch:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
yield from doc_batch
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
continue
try:
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
# if we found any PRs on the page, yield any associated documents and return the checkpoint
if not done_with_prs and len(pr_batch) > 0:
yield from doc_batch
return checkpoint
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.curr_page = 0
checkpoint.stage = GithubConnectorStage.ISSUES
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch = []
issue_batch = _get_batch_rate_limited(
issues, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_issues = False
for issue in cast(list[Issue], issue_batch):
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
done_with_issues = True
break
# Skip PRs updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
doc_batch.append(_convert_issue_to_document(issue))
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, yield them and return the checkpoint
if not done_with_issues and len(issue_batch) > 0:
yield from doc_batch
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
checkpoint.cached_repo = SerializedRepository(
id=next_id,
headers=next_repo.raw_headers,
raw_data=next_repo.raw_data,
)
return checkpoint
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
# Could be due to delayed processing on GitHub side
# The non-updated issues since last poll will be shortcut-ed and not embedded
adjusted_start_datetime = start_datetime - timedelta(hours=3)
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
epoch = datetime.utcfromtimestamp(0)
if adjusted_start_datetime < epoch:
adjusted_start_datetime = epoch
return self._fetch_from_github(
checkpoint, start=adjusted_start_datetime, end=end_datetime
)
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
def validate_connector_settings(self) -> None:
if self.github_client is None:
@@ -544,16 +397,6 @@ class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
f"Unexpected error during GitHub settings validation: {exc}"
)
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
)
if __name__ == "__main__":
import os
@@ -563,9 +406,7 @@ if __name__ == "__main__":
repositories=os.environ["REPOSITORIES"],
)
connector.load_credentials(
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
)
document_batches = connector.load_from_checkpoint(
0, time.time(), connector.build_dummy_checkpoint()
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -7,8 +7,6 @@ from typing import Any
from typing import cast
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import GONG_CONNECTOR_START_TIME
@@ -23,12 +21,13 @@ from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
GONG_BASE_URL = "https://us-34014.api.gong.io"
class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
def __init__(
self,
workspaces: list[str] | None = None,
@@ -42,23 +41,15 @@ class GongConnector(LoadConnector, PollConnector):
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
retry_strategy = Retry(
total=5,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
def _get_auth_header(self) -> dict[str, str]:
if self.auth_token_basic is None:
raise ConnectorMissingCredentialError("Gong")
session = requests.Session()
session.mount(GongConnector.BASE_URL, HTTPAdapter(max_retries=retry_strategy))
self._session = session
@staticmethod
def make_url(endpoint: str) -> str:
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
return {"Authorization": f"Basic {self.auth_token_basic}"}
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
url = f"{GONG_BASE_URL}/v2/workspaces"
response = requests.get(url, headers=self._get_auth_header())
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -75,6 +66,7 @@ class GongConnector(LoadConnector, PollConnector):
def _get_transcript_batches(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Generator[list[dict[str, Any]], None, None]:
url = f"{GONG_BASE_URL}/v2/calls/transcript"
body: dict[str, dict] = {"filter": {}}
if start_datetime:
body["filter"]["fromDateTime"] = start_datetime
@@ -102,8 +94,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = requests.post(
url, headers=self._get_auth_header(), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -133,14 +125,14 @@ class GongConnector(LoadConnector, PollConnector):
yield transcripts
def _get_call_details_by_ids(self, call_ids: list[str]) -> dict:
url = f"{GONG_BASE_URL}/v2/calls/extensive"
body = {
"filter": {"callIds": call_ids},
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
)
response = requests.post(url, headers=self._get_auth_header(), json=body)
response.raise_for_status()
calls = response.json().get("calls")
@@ -188,17 +180,9 @@ class GongConnector(LoadConnector, PollConnector):
call_id = transcript.get("callId")
if not call_id or call_id not in call_details_map:
# NOTE(rkuo): seeing odd behavior where call_ids from the transcript
# don't have call details. adding error debugging logs to trace.
logger.error(
f"Couldn't get call information for Call ID: {call_id}"
)
if call_id:
logger.error(
f"Call debug info: call_id={call_id} "
f"call_ids={call_ids} "
f"call_details_map={call_details_map.keys()}"
)
if not self.continue_on_fail:
raise RuntimeError(
f"Couldn't get call information for Call ID: {call_id}"
@@ -279,13 +263,6 @@ class GongConnector(LoadConnector, PollConnector):
self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode(
"utf-8"
)
if self.auth_token_basic is None:
raise ConnectorMissingCredentialError("Gong")
self._session.headers.update(
{"Authorization": f"Basic {self.auth_token_basic}"}
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:

View File

@@ -2,11 +2,11 @@ import copy
import threading
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from functools import partial
from typing import Any
from typing import cast
from typing import Protocol
from urllib.parse import urlparse
@@ -15,7 +15,6 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
from onyx.configs.app_configs import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_DRIVE_WORKERS
from onyx.configs.constants import DocumentSource
@@ -28,9 +27,7 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -60,13 +57,13 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import ThreadSafeDict
logger = setup_logger()
@@ -88,18 +85,12 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -112,7 +103,6 @@ def _convert_single_file(
drive_service=user_drive_service,
docs_service=docs_service,
allow_images=allow_images,
size_threshold=size_threshold,
)
@@ -248,8 +238,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
self._retrieved_ids: set[str] = set()
self.allow_images = False
self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@@ -445,9 +433,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.warning(
f"User '{user_email}' does not have access to the drive APIs."
)
# mark this user as done so we don't try to retrieve anything for them
# again
curr_stage.stage = DriveRetrievalStage.DONE
return
raise
@@ -460,11 +445,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -472,7 +456,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
DriveRetrievalStage.MY_DRIVE_FILES,
)
curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
@@ -508,7 +491,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
yield from _yield_from_drive(drive_id, start)
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
def _yield_from_folder_crawl(
@@ -561,16 +544,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)
# Setup initial completion map on first connector run
for email in all_org_emails:
# don't overwrite existing completion map on resuming runs
if email in checkpoint.completion_map:
continue
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
# we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
@@ -584,25 +557,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
# only process emails that we haven't already completed retrieval for
non_completed_org_emails = [
user_email
for user_email, stage in checkpoint.completion_map.items()
if stage != DriveRetrievalStage.DONE
]
# don't process too many emails before returning a checkpoint. This is
# to resolve the case where there are a ton of emails that don't have access
# to the drive APIs. Without this, we could loop through these emails for
# more than 3 hours, causing a timeout and stalling progress.
email_batch_takes_us_to_completion = True
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
non_completed_org_emails = non_completed_org_emails[
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
]
email_batch_takes_us_to_completion = False
for email in all_org_emails:
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -613,14 +572,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
start,
end,
)
for email in non_completed_org_emails
for email in all_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
# if there are more emails to process, don't mark as complete
if not email_batch_takes_us_to_completion:
return
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
@@ -837,12 +792,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
return
for file in drive_files:
if file.error is None:
if file.error is not None:
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
completed_until_parent_id=file.parent_id,
)
yield file
@@ -944,86 +897,117 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[Document | ConnectorFailure]:
) -> Iterator[list[Document | ConnectorFailure]]:
try:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
def _yield_batch(
files_batch: list[RetrievedDriveFile],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
)
docs_and_failures = [result for result in results if result is not None]
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
]
continue
files_batch.append(retrieved_file.drive_file)
if docs_and_failures:
yield from docs_and_failures
batches_complete += 1
if len(files_batch) < self.batch_size:
continue
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
continue
files_batch.append(retrieved_file)
if documents:
yield documents
batches_complete += 1
files_batch = []
if len(files_batch) < self.batch_size:
continue
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
yield from _yield_batch(files_batch)
files_batch = []
# Process any remaining files
if files_batch:
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
# Process any remaining files
if files_batch:
yield from _yield_batch(files_batch)
if documents:
yield documents
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
@@ -1045,7 +1029,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
try:
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
for doc_list in self._extract_docs_from_google_drive(
checkpoint, start, end
):
yield from doc_list
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
@@ -1080,7 +1067,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(

View File

@@ -30,7 +30,6 @@ from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -77,27 +76,7 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
return is_valid_image_type(mime_type)
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
"""
Download the file from Google Drive.
"""
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_id}")
return bytes()
return response
def _download_and_extract_sections_basic(
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
@@ -108,17 +87,35 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -127,97 +124,88 @@ def _download_and_extract_sections_basic(
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
logger.warning(f"Failed to download {file_name}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
response_call = lazy_eval(lambda: download_request(service, file_id))
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response_call().decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
else:
# For unsupported file types, try to extract text
if mime_type in [
"application/vnd.google-apps.video",
"application/vnd.google-apps.audio",
"application/zip",
]:
return []
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
def convert_drive_item_to_document(
@@ -225,7 +213,6 @@ def convert_drive_item_to_document(
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
size_threshold: int,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
@@ -253,24 +240,9 @@ def convert_drive_item_to_document(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
)
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _download_and_extract_sections_basic(
file, drive_service(), allow_images
)
sections = _extract_sections_basic(file, drive_service(), allow_images)
# If we still don't have any sections, skip this file
if not sections:

View File

@@ -123,7 +123,7 @@ def crawl_folders_for_files(
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}, user email: {user_email}")
logger.info(f"Found file: {file['name']}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
@@ -214,11 +214,10 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive_and_shared(
def get_all_files_in_my_drive(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -230,8 +229,7 @@ def get_all_files_in_my_drive_and_shared(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -248,8 +246,7 @@ def get_all_files_in_my_drive_and_shared(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

View File

@@ -75,7 +75,7 @@ class HighspotClient:
self.key = key
self.secret = secret
self.base_url = base_url.rstrip("/") + "/"
self.base_url = base_url
self.timeout = timeout
# Set up session with retry logic

View File

@@ -20,9 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -85,21 +84,14 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
Populate the spot ID map with all available spots.
Keys are stored as lowercase for case-insensitive lookups.
"""
try:
spots = self.client.get_spots()
for spot in spots:
if "title" in spot and "id" in spot:
spot_name = spot["title"]
self._spot_id_map[spot_name.lower()] = spot["id"]
spots = self.client.get_spots()
for spot in spots:
if "title" in spot and "id" in spot:
spot_name = spot["title"]
self._spot_id_map[spot_name.lower()] = spot["id"]
self._all_spots_fetched = True
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
except HighspotClientError as e:
logger.error(f"Error retrieving spots from Highspot: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error retrieving spots from Highspot: {str(e)}")
raise
self._all_spots_fetched = True
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
def _get_all_spot_names(self) -> List[str]:
"""
@@ -159,142 +151,116 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
Batches of Document objects
"""
doc_batch: list[Document] = []
try:
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
if not spot_names_to_process:
logger.warning("No spots found in Highspot")
raise ValueError("No spots found in Highspot")
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots"
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
if spot_id is None:
logger.warning(f"Spot ID not found for spot {spot_name}")
continue
offset = 0
has_more = True
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots"
)
while has_more:
logger.info(
f"Retrieving items from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
items = response.get("collection", [])
logger.info(f"Received Items: {items}")
if not items:
has_more = False
continue
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
if spot_id is None:
logger.warning(f"Spot ID not found for spot {spot_name}")
continue
offset = 0
has_more = True
for item in items:
try:
item_id = item.get("id")
if not item_id:
logger.warning("Item without ID found, skipping")
continue
item_details = self.client.get_item(item_id)
if not item_details:
logger.warning(
f"Item {item_id} details not found, skipping"
)
continue
# Apply time filter if specified
if start or end:
updated_at = item_details.get("date_updated")
if updated_at:
# Convert to datetime for comparison
try:
updated_time = datetime.fromisoformat(
updated_at.replace("Z", "+00:00")
)
if (
start
and updated_time.timestamp() < start
) or (
end and updated_time.timestamp() > end
):
continue
except (ValueError, TypeError):
# Skip if date cannot be parsed
logger.warning(
f"Invalid date format for item {item_id}: {updated_at}"
)
continue
content = self._get_item_content(item_details)
title = item_details.get("title", "")
doc_batch.append(
Document(
id=f"HIGHSPOT_{item_id}",
sections=[
TextSection(
link=item_details.get(
"url",
f"https://www.highspot.com/items/{item_id}",
),
text=content,
)
],
source=DocumentSource.HIGHSPOT,
semantic_identifier=title,
metadata={
"spot_name": spot_name,
"type": item_details.get(
"content_type", ""
),
"created_at": item_details.get(
"date_added", ""
),
"author": item_details.get("author", ""),
"language": item_details.get(
"language", ""
),
"can_download": str(
item_details.get("can_download", False)
),
},
doc_updated_at=item_details.get("date_updated"),
)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
except HighspotClientError as e:
item_id = "ID" if not item_id else item_id
logger.error(
f"Error retrieving item {item_id}: {str(e)}"
)
except Exception as e:
item_id = "ID" if not item_id else item_id
logger.error(
f"Unexpected error for item {item_id}: {str(e)}"
)
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(f"Error processing spot {spot_name}: {str(e)}")
except Exception as e:
logger.error(
f"Unexpected error processing spot {spot_name}: {str(e)}"
while has_more:
logger.info(
f"Retrieving items from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
items = response.get("collection", [])
logger.info(f"Received Items: {items}")
if not items:
has_more = False
continue
except Exception as e:
logger.error(f"Error in Highspot connector: {str(e)}")
raise
for item in items:
try:
item_id = item.get("id")
if not item_id:
logger.warning("Item without ID found, skipping")
continue
item_details = self.client.get_item(item_id)
if not item_details:
logger.warning(
f"Item {item_id} details not found, skipping"
)
continue
# Apply time filter if specified
if start or end:
updated_at = item_details.get("date_updated")
if updated_at:
# Convert to datetime for comparison
try:
updated_time = datetime.fromisoformat(
updated_at.replace("Z", "+00:00")
)
if (
start and updated_time.timestamp() < start
) or (end and updated_time.timestamp() > end):
continue
except (ValueError, TypeError):
# Skip if date cannot be parsed
logger.warning(
f"Invalid date format for item {item_id}: {updated_at}"
)
continue
content = self._get_item_content(item_details)
title = item_details.get("title", "")
doc_batch.append(
Document(
id=f"HIGHSPOT_{item_id}",
sections=[
TextSection(
link=item_details.get(
"url",
f"https://www.highspot.com/items/{item_id}",
),
text=content,
)
],
source=DocumentSource.HIGHSPOT,
semantic_identifier=title,
metadata={
"spot_name": spot_name,
"type": item_details.get("content_type", ""),
"created_at": item_details.get(
"date_added", ""
),
"author": item_details.get("author", ""),
"language": item_details.get("language", ""),
"can_download": str(
item_details.get("can_download", False)
),
},
doc_updated_at=item_details.get("date_updated"),
)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
except HighspotClientError as e:
item_id = "ID" if not item_id else item_id
logger.error(f"Error retrieving item {item_id}: {str(e)}")
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(f"Error processing spot {spot_name}: {str(e)}")
if doc_batch:
yield doc_batch
@@ -320,9 +286,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
# Extract title and description once at the beginning
title, description = self._extract_title_and_description(item_details)
default_content = f"{title}\n{description}"
logger.info(
f"Processing item {item_id} with extension {file_extension} and file name {content_name}"
)
logger.info(f"Processing item {item_id} with extension {file_extension}")
try:
if content_type == "WebLink":
@@ -334,39 +298,30 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
elif (
is_valid_format
and (
file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
or file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
)
and file_extension in VALID_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content
if not item_id: # Ensure item_id is defined
return default_content
content_response = self.client.get_item_content(item_id)
# Process and extract text from binary content based on type
if content_response:
text_content = extract_file_text(
BytesIO(content_response), content_name, False
BytesIO(content_response), content_name
)
return text_content if text_content else default_content
return text_content
return default_content
else:
return default_content
except HighspotClientError as e:
error_context = f"item {item_id}" if item_id else "(item id not found)"
# Use item_id safely in the warning message
error_context = f"item {item_id}" if item_id else "item"
logger.warning(f"Could not retrieve content for {error_context}: {str(e)}")
return default_content
except ValueError as e:
error_context = f"item {item_id}" if item_id else "(item id not found)"
logger.error(f"Value error for {error_context}: {str(e)}")
return default_content
except Exception as e:
error_context = f"item {item_id}" if item_id else "(item id not found)"
logger.error(
f"Unexpected error retrieving content for {error_context}: {str(e)}"
)
return default_content
return ""
def _extract_title_and_description(
self, item_details: Dict[str, Any]
@@ -403,63 +358,55 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
Batches of SlimDocument objects
"""
slim_doc_batch: list[SlimDocument] = []
try:
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
if not spot_names_to_process:
logger.warning("No spots found in Highspot")
raise ValueError("No spots found in Highspot")
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
offset = 0
has_more = True
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
)
while has_more:
logger.info(
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
offset = 0
has_more = True
items = response.get("collection", [])
if not items:
has_more = False
continue
for item in items:
item_id = item.get("id")
if not item_id:
continue
slim_doc_batch.append(
SlimDocument(id=f"HIGHSPOT_{item_id}")
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
while has_more:
logger.info(
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
if slim_doc_batch:
yield slim_doc_batch
except Exception as e:
logger.error(f"Error in Highspot Slim Connector: {str(e)}")
raise
items = response.get("collection", [])
if not items:
has_more = False
continue
for item in items:
item_id = item.get("id")
if not item_id:
continue
slim_doc_batch.append(SlimDocument(id=f"HIGHSPOT_{item_id}"))
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
)
if slim_doc_batch:
yield slim_doc_batch
def validate_credentials(self) -> bool:
"""

View File

@@ -8,6 +8,7 @@ from typing import TypeAlias
from typing import TypeVar
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
@@ -230,7 +231,7 @@ class CheckpointConnector(BaseConnector[CT]):
"""
raise NotImplementedError
@abc.abstractmethod
@override
def build_dummy_checkpoint(self) -> CT:
raise NotImplementedError

View File

@@ -1,4 +1,3 @@
import sys
from datetime import datetime
from enum import Enum
from typing import Any
@@ -41,9 +40,6 @@ class TextSection(Section):
text: str
link: str | None = None
def __sizeof__(self) -> int:
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
class ImageSection(Section):
"""Section containing an image reference"""
@@ -51,9 +47,6 @@ class ImageSection(Section):
image_file_name: str
link: str | None = None
def __sizeof__(self) -> int:
return sys.getsizeof(self.image_file_name) + sys.getsizeof(self.link)
class BasicExpertInfo(BaseModel):
"""Basic Information for the owner of a document, any of the fields can be left as None
@@ -117,14 +110,6 @@ class BasicExpertInfo(BaseModel):
)
)
def __sizeof__(self) -> int:
size = sys.getsizeof(self.display_name)
size += sys.getsizeof(self.first_name)
size += sys.getsizeof(self.middle_initial)
size += sys.getsizeof(self.last_name)
size += sys.getsizeof(self.email)
return size
class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
@@ -178,35 +163,6 @@ class DocumentBase(BaseModel):
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
def __sizeof__(self) -> int:
size = sys.getsizeof(self.id)
for section in self.sections:
size += sys.getsizeof(section)
size += sys.getsizeof(self.source)
size += sys.getsizeof(self.semantic_identifier)
size += sys.getsizeof(self.doc_updated_at)
size += sys.getsizeof(self.chunk_count)
if self.primary_owners is not None:
for primary_owner in self.primary_owners:
size += sys.getsizeof(primary_owner)
else:
size += sys.getsizeof(self.primary_owners)
if self.secondary_owners is not None:
for secondary_owner in self.secondary_owners:
size += sys.getsizeof(secondary_owner)
else:
size += sys.getsizeof(self.secondary_owners)
size += sys.getsizeof(self.title)
size += sys.getsizeof(self.from_ingestion_api)
size += sys.getsizeof(self.additional_info)
return size
def get_text_content(self) -> str:
return " ".join([section.text for section in self.sections if section.text])
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""
@@ -235,12 +191,6 @@ class Document(DocumentBase):
from_ingestion_api=base.from_ingestion_api,
)
def __sizeof__(self) -> int:
size = super().__sizeof__()
size += sys.getsizeof(self.id)
size += sys.getsizeof(self.source)
return size
class IndexingDocument(Document):
"""Document with processed sections for indexing"""

View File

@@ -1,9 +1,4 @@
import gc
import os
import sys
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any
from simple_salesforce import Salesforce
@@ -26,13 +21,9 @@ from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_t
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import sqlite_log_stats
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -41,8 +32,6 @@ _DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
MAX_BATCH_BYTES = 1024 * 1024
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -75,45 +64,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Salesforce")
return self._sf_client
@staticmethod
def reconstruct_object_types(directory: str) -> dict[str, list[str] | None]:
"""
Scans the given directory for all CSV files and reconstructs the available object types.
Assumes filenames are formatted as "ObjectType.filename.csv" or "ObjectType.csv".
Args:
directory (str): The path to the directory containing CSV files.
Returns:
dict[str, list[str]]: A dictionary mapping object types to lists of file paths.
"""
object_types = defaultdict(list)
for filename in os.listdir(directory):
if filename.endswith(".csv"):
parts = filename.split(".", 1) # Split on the first period
object_type = parts[0] # Take the first part as the object type
object_types[object_type].append(os.path.join(directory, filename))
return dict(object_types)
@staticmethod
def _download_object_csvs(
directory: str,
parent_object_list: list[str],
sf_client: Salesforce,
def _fetch_from_salesforce(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> None:
all_object_types: set[str] = set(parent_object_list)
) -> GenerateDocumentsOutput:
init_db()
all_object_types: set[str] = set(self.parent_object_list)
logger.info(
f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
)
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
logger.debug(f"Parent object types: {self.parent_object_list}")
# This takes like 20 seconds
for parent_object_type in parent_object_list:
child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
for parent_object_type in self.parent_object_list:
child_types = get_all_children_of_sf_type(
self.sf_client, parent_object_type
)
all_object_types.update(child_types)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
@@ -122,53 +88,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# Always want to make sure user is grabbed for permissioning purposes
all_object_types.add("User")
logger.info(
f"All object types: num={len(all_object_types)} list={all_object_types}"
)
# gc.collect()
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Fetching CSVs for all object types")
logger.info("Starting to fetch CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_client=sf_client,
sf_client=self.sf_client,
object_types=all_object_types,
start=start,
end=end,
target_dir=directory,
)
# print useful information
num_csvs = 0
num_bytes = 0
for object_type, csv_paths in object_type_to_csv_path.items():
if not csv_paths:
continue
for csv_path in csv_paths:
if not csv_path:
continue
file_path = Path(csv_path)
file_size = file_path.stat().st_size
num_csvs += 1
num_bytes += file_size
logger.info(
f"CSV info: object_type={object_type} path={csv_path} bytes={file_size}"
)
logger.info(f"CSV info total: total_csvs={num_csvs} total_bytes={num_bytes}")
@staticmethod
def _load_csvs_to_db(csv_directory: str, db_directory: str) -> set[str]:
updated_ids: set[str] = set()
object_type_to_csv_path = SalesforceConnector.reconstruct_object_types(
csv_directory
)
# This takes like 10 seconds
# This is for testing the rest of the functionality if data has
# already been fetched and put in sqlite
@@ -187,16 +120,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# If path is None, it means it failed to fetch the csv
if csv_paths is None:
continue
# Go through each csv path and use it to update the db
for csv_path in csv_paths:
logger.debug(
f"Processing CSV: object_type={object_type} "
f"csv={csv_path} "
f"len={Path(csv_path).stat().st_size}"
)
logger.debug(f"Updating {object_type} with {csv_path}")
new_ids = update_sf_db_with_csv(
db_directory,
object_type=object_type,
csv_download_path=csv_path,
)
@@ -205,127 +132,49 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
f"Added {len(new_ids)} new/updated records for {object_type}"
)
os.remove(csv_path)
return updated_ids
def _fetch_from_salesforce(
self,
temp_dir: str,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
logger.info("_fetch_from_salesforce starting.")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
init_db(temp_dir)
sqlite_log_stats(temp_dir)
# Step 1 - download
SalesforceConnector._download_object_csvs(
temp_dir, self.parent_object_list, self._sf_client, start, end
)
gc.collect()
# Step 2 - load CSV's to sqlite
updated_ids = SalesforceConnector._load_csvs_to_db(temp_dir, temp_dir)
gc.collect()
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
f"Starting to process parent objects of types: {self.parent_object_list}"
)
# Step 3 - extract and index docs
batches_processed = 0
docs_processed = 0
docs_to_yield: list[Document] = []
docs_to_yield_bytes = 0
docs_processed = 0
# Takes 15-20 seconds per batch
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
temp_dir,
updated_ids=list(updated_ids),
parent_types=self.parent_object_list,
):
batches_processed += 1
logger.info(
f"Processing batch: index={batches_processed} "
f"object_type={parent_type} "
f"len={len(parent_id_batch)} "
f"processed={docs_processed} "
f"remaining={len(updated_ids) - docs_processed}"
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
)
for parent_id in parent_id_batch:
if not (parent_object := get_record(temp_dir, parent_id, parent_type)):
if not (parent_object := get_record(parent_id, parent_type)):
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
doc = convert_sf_object_to_doc(
temp_dir,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
docs_to_yield.append(
convert_sf_object_to_doc(
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
docs_processed += 1
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
if len(docs_to_yield) >= self.batch_size:
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
yield docs_to_yield
logger.info(
f"Final processing stats: "
f"processed={docs_processed} "
f"remaining={len(updated_ids) - docs_processed}"
)
def load_from_state(self) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._fetch_from_salesforce(temp_dir)
# nuke the db since we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
os.remove(sqlite_db_path)
return self._fetch_from_salesforce(BASE_DATA_PATH)
return self._fetch_from_salesforce()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._fetch_from_salesforce(temp_dir, start=start, end=end)
if start == 0:
# nuke the db if we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
)
os.remove(sqlite_db_path)
return self._fetch_from_salesforce(BASE_DATA_PATH)
return self._fetch_from_salesforce(start=start, end=end)
def retrieve_all_slim_documents(
self,
@@ -360,7 +209,7 @@ if __name__ == "__main__":
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
start_time = time.monotonic()
start_time = time.time()
doc_count = 0
section_count = 0
text_count = 0
@@ -372,7 +221,7 @@ if __name__ == "__main__":
for section in doc.sections:
if isinstance(section, TextSection) and section.text is not None:
text_count += len(section.text)
end_time = time.monotonic()
end_time = time.time()
print(f"Doc count: {doc_count}")
print(f"Section count: {section_count}")

Some files were not shown because too many files have changed in this diff Show More