Compare commits

..

56 Commits

Author SHA1 Message Date
Weves
4d1d5bdcfe tweak worker count 2025-02-16 14:51:10 -08:00
Weves
0d468b49a1 Final fixes 2025-02-16 14:50:51 -08:00
Weves
67b87ced39 fix paths 2025-02-16 14:23:51 -08:00
Weves
8b4e4a6c80 Fix paths 2025-02-16 14:22:12 -08:00
Weves
e26bcf5a05 Misc fixes 2025-02-16 14:02:16 -08:00
Weves
435959cf90 test 2025-02-16 14:02:16 -08:00
Weves
fcbe305dc0 test 2025-02-16 14:02:16 -08:00
Weves
6f13d44564 move 2025-02-16 14:02:16 -08:00
Weves
c1810a35cd Fix redis port 2025-02-16 14:02:16 -08:00
Weves
4003e7346a test 2025-02-16 14:02:16 -08:00
Weves
8057f1eb0d Add logging 2025-02-16 14:02:16 -08:00
Weves
7eebd3cff1 test not removing files 2025-02-16 14:02:16 -08:00
Weves
bac2aeb8b7 Fix 2025-02-16 14:02:16 -08:00
Weves
9831697acc Make migrations work 2025-02-16 14:02:15 -08:00
Weves
5da766dd3b testing 2025-02-16 14:01:50 -08:00
Weves
180608694a improvements 2025-02-16 14:01:49 -08:00
Weves
96b92edfdb Parallelize IT
Parallelization

Full draft of first pass

Adjsut test name

test

test

Fix

Update cmd

test

Fix

test

Test with all tests

Resource bump + limit num parallel runs

Add retries
2025-02-16 14:01:12 -08:00
pablonyx
ec0e55fd39 Seeding count issue (#4009)
* k

* k

* quick nit

* nit
2025-02-16 20:49:25 +00:00
pablonyx
e441c899af Playwright + Chromatic update (#4015) 2025-02-16 13:03:45 -08:00
Chris Weaver
f1fc8ac19b Connector checkpointing (#3876)
* wip checkpointing/continue on failure

more stuff for checkpointing

Basic implementation

FE stuff

More checkpointing/failure handling

rebase

rebase

initial scaffolding for IT

IT to test checkpointing

Cleanup

cleanup

Fix it

Rebase

Add todo

Fix actions IT

Test more

Pagination + fixes + cleanup

Fix IT networking

fix it

* rebase

* Address misc comments

* Address comments

* Remove unused router

* rebase

* Fix mypy

* Fixes

* fix it

* Fix tests

* Add drop index

* Add retries

* reset lock timeout

* Try hard drop of schema

* Add timeout/retries to downgrade

* rebase

* test

* test

* test

* Close all connections

* test closing idle only

* Fix it

* fix

* try using null pool

* Test

* fix

* rebase

* log

* Fix

* apply null pool

* Fix other test

* Fix quality checks

* Test not using the fixture

* Fix ordering

* fix test

* Change pooling behavior
2025-02-16 02:34:39 +00:00
Weves
bc087fc20e Fix ruff 2025-02-15 16:35:15 -08:00
Yuhong Sun
ab8081c36b k 2025-02-15 13:42:43 -08:00
Adam Siemiginowski
f371efc916 Fix Zulip connector schema + links and enable temporal metadata (#4005) 2025-02-15 11:49:41 -08:00
pablonyx
7fd5d31dbe Minor background process log cleanup (#4010) 2025-02-15 11:03:10 -08:00
rkuo-danswer
2829e6715e Feature/propagate exceptions (#3974)
* better propagation of exceptions up the stack

* remove debug testing

* refactor the watchdog more to emit data consistently at the end of the function

* enumerate a lot more terminal statuses

* handle more codes

* improve logging

* handle "-9"

* single line exception logging

* typo/grammar

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-15 04:53:01 +00:00
Weves
bc7b4ec396 Fix typing for metadata 2025-02-14 18:19:37 -08:00
pablonyx
697f8bc1c6 Reduce background errors (#4004) 2025-02-14 17:35:26 -08:00
evan-danswer
3ba65214b8 bump version and fix related issues (#3996) 2025-02-14 19:57:12 +00:00
joachim-danswer
6687d5d499 major Agent Search Updates (#3994) 2025-02-14 19:40:21 +00:00
pablonyx
ec78f78f3c k (#3999) 2025-02-14 02:33:42 +00:00
rkuo-danswer
ed253e469a add nano and vim to base image (#3995)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-14 02:27:24 +00:00
pablodanswer
e3aafd95af k 2025-02-13 18:34:05 -08:00
Weves
3a704f1950 Add new vars to github action 2025-02-13 18:33:17 -08:00
Weves
2bf8a7aee5 Misc improvements 2025-02-13 18:33:17 -08:00
Weves
c2f3302aa0 Fix mypy 2025-02-13 18:33:17 -08:00
neo773
7f4d1f27a0 Gitbook connector (#3991)
* add parser

* add tests
2025-02-13 17:58:05 -08:00
pablonyx
b70db15622 Bugfix Vespa Deletion Script (#3998) 2025-02-13 17:26:04 -08:00
pablonyx
e9492ce9ec minor read replica fix (#3997) 2025-02-13 17:11:45 -08:00
pablodanswer
35574369ed update cloud build to use public stripe key 2025-02-13 16:55:56 -08:00
pablonyx
eff433bdc5 Reduce errors in workers (#3962) 2025-02-13 15:59:44 -08:00
pablonyx
3260d793d1 Billing fixes (#3976) 2025-02-13 15:59:10 -08:00
Yuhong Sun
1a7aca06b9 Fix Agent Slowness (#3979) 2025-02-13 15:54:34 -08:00
pablonyx
c6434db7eb Add delete all for tenants in Vespa (#3970) 2025-02-13 14:33:49 -08:00
joachim-danswer
667b9e04c5 updated rerank function arguments (#3988) 2025-02-13 14:13:14 -08:00
rkuo-danswer
29c84d7707 xfail this test (#3992)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-13 14:09:15 -08:00
pablonyx
17c915b11b Improved email formatting (#3985)
* prettier emails

* k

* remove mislieading comment

* minor typing
2025-02-13 21:11:57 +00:00
rkuo-danswer
95ca592d6d fix title check (#3993)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-13 13:14:55 -08:00
Yuhong Sun
e39a27fd6b Hope this actually skips the model server builds now (#3987) 2025-02-13 11:48:25 -08:00
rkuo-danswer
26d3c952c6 Bugfix/jira connector test 2 (#3986)
* fix jira connector test

* typo fix

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-13 10:21:54 -08:00
rkuo-danswer
53683e2f3c fix jira connector test (#3983)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-13 09:41:45 -08:00
rkuo-danswer
0c0113a481 ignore result when using send_task on lightweight tasks (#3978)
* ignore result when using send_task on lightweight tasks

* fix ignore_result

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-13 03:22:13 -08:00
Chris Weaver
c0f381e471 Add background errors ability (#3982) 2025-02-13 00:44:55 -08:00
rkuo-danswer
5ed83f1148 no thread local locks in callbacks and raise permission sync timeout … (#3977)
* no thread local locks in callbacks and raise permission sync timeout by a lot based on empirical log observations

* more fixes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-12 22:31:01 -08:00
pablonyx
9db7b67a6c Minor misc ux improvements (#3966)
* minor misc ux

* nit

* k

* quick nit

* k
2025-02-13 04:43:11 +00:00
Yuhong Sun
2850048c6b Jira add key to semantic id (#3981) 2025-02-12 20:04:47 -08:00
rkuo-danswer
61058e5fcd merge monitoring with kickoff tasks (#3953)
* move indexing

* all monitor work moved

* reacquire lock more

* remove monitor task completely

* fix import

* fix pruning finalization

* no multiplier on system/cloud tasks

* monitor queues every 30 seconds in the cloud

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-13 02:35:41 +00:00
219 changed files with 10063 additions and 3038 deletions

View File

@@ -65,6 +65,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

View File

@@ -4,9 +4,6 @@ on:
push:
tags:
- "*"
paths:
- 'backend/model_server/**'
- 'backend/Dockerfile.model_server'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
@@ -15,7 +12,32 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# 1) Preliminary job to check if the changed files are relevant
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
steps:
@@ -55,6 +77,8 @@ jobs:
provenance: false
build-arm64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
steps:
@@ -94,7 +118,8 @@ jobs:
provenance: false
merge-and-scan:
needs: [build-amd64, build-arm64]
needs: [build-amd64, build-arm64, check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Login to Docker Hub

View File

@@ -0,0 +1,153 @@
name: Run Integration Tests v3
concurrency:
group: Run-Integration-Tests-Parallel-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[runs-on, runner=32cpu-linux-x64, ram=64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Run Standard Integration Tests
run: |
# Print a message indicating that tests are starting
echo "Running integration tests..."
# Create a directory for test logs that will be mounted into the container
mkdir -p ${{ github.workspace }}/test_logs
chmod 777 ${{ github.workspace }}/test_logs
# Run the integration tests in a Docker container
# Mount the Docker socket to allow Docker-in-Docker (DinD)
# Mount the test_logs directory to capture logs
# Use host network for easier communication with other services
docker run \
-v /var/run/docker.sock:/var/run/docker.sock \
-v ${{ github.workspace }}/test_logs:/tmp \
--network host \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
danswer/danswer-integration:test \
python /app/tests/integration/run.py
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Collect log files
if: success() || failure()
run: |
# Create a directory for logs
mkdir -p ${{ github.workspace }}/logs
mkdir -p ${{ github.workspace }}/logs/shared_services
# Copy all relevant log files from the mounted directory
cp ${{ github.workspace }}/test_logs/api_server_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/background_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/shared_model_server.txt ${{ github.workspace }}/logs/ || true
# Collect logs from shared services (Docker containers)
# Note: using a wildcard for the UUID part of the stack name
docker ps -a --filter "name=base-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# Also collect Redis container logs
docker ps -a --filter "name=redis-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# List collected logs
echo "Collected log files:"
ls -l ${{ github.workspace }}/logs/
echo "Collected shared services logs:"
ls -l ${{ github.workspace }}/logs/shared_services/
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: integration-test-logs
path: |
${{ github.workspace }}/logs/
${{ github.workspace }}/logs/shared_services/
retention-days: 5
# save before stopping the containers so the logs can be captured
# - name: Save Docker logs
# if: success() || failure()
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
# mv docker-compose.log ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# - name: Upload logs
# if: success() || failure()
# uses: actions/upload-artifact@v4
# with:
# name: docker-logs
# path: ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -5,10 +5,10 @@ concurrency:
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
# pull_request:
# branches:
# - main
# - "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -99,7 +99,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
@@ -108,12 +108,13 @@ jobs:
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -143,24 +144,27 @@ jobs:
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -190,15 +194,24 @@ jobs:
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -208,6 +221,8 @@ jobs:
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
@@ -229,13 +244,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -249,4 +264,4 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -1,6 +1,6 @@
name: Run Chromatic Tests
name: Run Playwright Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
@@ -198,43 +198,47 @@ jobs:
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
needs: playwright-tests
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
# chromatic-tests:
# name: Chromatic Tests
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Install node dependencies
working-directory: ./web
run: npm ci
# - name: Setup node
# uses: actions/setup-node@v4
# with:
# node-version: 22
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results
# - name: Download Playwright test results
# uses: actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -44,6 +44,9 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
jobs:
connectors-check:

View File

@@ -205,7 +205,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
],
"presentation": {
"group": "2",

View File

@@ -35,7 +35,9 @@ RUN apt-get update && \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc && \
gcc \
nano \
vim && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.db.engine import SYNC_DB_API, get_iam_auth_token
from onyx.configs.app_configs import POSTGRES_DB, USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
@@ -13,12 +13,11 @@ from sqlalchemy import text
from sqlalchemy.engine.base import Connection
import os
import ssl
import asyncio
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import create_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
@@ -133,17 +132,32 @@ def provide_iam_token_for_alembic(
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
def run_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
# Get any environment variables passed through alembic config
env_vars = context.config.attributes.get("env_vars", {})
# Use env vars if provided, otherwise fall back to defaults
postgres_host = env_vars.get("POSTGRES_HOST", POSTGRES_HOST)
postgres_port = env_vars.get("POSTGRES_PORT", POSTGRES_PORT)
postgres_user = env_vars.get("POSTGRES_USER", POSTGRES_USER)
postgres_db = env_vars.get("POSTGRES_DB", POSTGRES_DB)
engine = create_engine(
build_connection_string(
db=postgres_db,
user=postgres_user,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
),
poolclass=pool.NullPool,
)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
@@ -152,31 +166,26 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
if schema is None:
continue
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema_name, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
engine.dispose()
def run_migrations_offline() -> None:
@@ -184,18 +193,18 @@ def run_migrations_offline() -> None:
url = build_connection_string()
if upgrade_all_tenants:
engine = create_async_engine(url)
engine = create_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
@@ -230,7 +239,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
asyncio.run(run_async_migrations())
run_migrations()
if context.is_offline_mode():

View File

@@ -0,0 +1,124 @@
"""Add checkpointing/failure handling
Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt",
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"ix_index_attempt_cc_pair_settings_poll",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
"status",
sa.text("time_updated DESC"),
],
)
# Drop the old IndexAttemptError table
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# Create the new version of the table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("document_link", sa.String(), nullable=True),
sa.Column("entity_id", sa.String(), nullable=True),
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("failure_message", sa.Text(), nullable=False),
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
)
def downgrade() -> None:
op.execute("SET lock_timeout = '5s'")
# try a few times to drop the table, this has been observed to fail due to other locks
# blocking the drop
NUM_TRIES = 10
for i in range(NUM_TRIES):
try:
op.drop_table("index_attempt_errors")
break
except Exception as e:
if i == NUM_TRIES - 1:
raise e
print(f"Error dropping table: {e}. Retrying...")
op.execute("SET lock_timeout = DEFAULT")
# Recreate the old IndexAttemptError table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
)
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
op.drop_column("index_attempt", "checkpoint_pointer")
op.drop_column("index_attempt", "poll_range_start")
op.drop_column("index_attempt", "poll_range_end")

View File

@@ -0,0 +1,40 @@
"""Add background errors table
Revision ID: f39c5794c10a
Revises: 2cdeff6d8c93
Create Date: 2025-02-12 17:11:14.527876
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f39c5794c10a"
down_revision = "2cdeff6d8c93"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"background_error",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("message", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
def downgrade() -> None:
op.drop_table("background_error")

View File

@@ -1,10 +1,10 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_system_tasks as base_beat_system_tasks,
beat_cloud_tasks as base_beat_system_tasks,
)
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_task_templates as base_beat_task_templates,
)

View File

@@ -77,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
GATED_TENANTS_KEY = "gated_tenants"

View File

@@ -1,5 +1,6 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
@@ -10,7 +11,7 @@ logger = setup_logger()
def _build_group_member_email_map(
confluence_client: OnyxConfluence,
confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
@@ -18,8 +19,11 @@ def _build_group_member_email_map(
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
msg = f"user result missing user field: {user_result}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server
@@ -32,7 +36,12 @@ def _build_group_member_email_map(
)
if not email:
# If we still don't have an email, skip this user
logger.warning(f"user result missing email field: {user_result}")
msg = f"user result missing email field: {user_result}"
if user.get("type") == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue
all_users_groups: set[str] = set()
@@ -42,11 +51,18 @@ def _build_group_member_email_map(
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
if not all_users_groups:
msg = f"No groups found for user with email: {email}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
if not group_member_emails:
msg = "No groups found for any users."
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
return group_member_emails
@@ -61,6 +77,7 @@ def confluence_group_sync(
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,
cc_pair_id=cc_pair.id,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()

View File

@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

View File

@@ -83,6 +83,7 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)

View File

@@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
@@ -39,12 +44,9 @@ from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -126,37 +128,29 @@ async def login_as_anonymous_user(
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
We gate the product when their subscription has ended.
"""
tenant_id = product_gating_request.tenant_id
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)
if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information", response_model=BillingInformation)
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return BillingInformation(
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
@router.post("/create-customer-portal-session")
@@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
@@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,

View File

@@ -6,6 +6,7 @@ import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
@@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> dict:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = response.json()
billing_info = BillingInformation(**response.json())
return billing_info

View File

@@ -1,7 +1,8 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
from onyx.server.settings.models import ApplicationStatus
class CheckoutSessionCreationRequest(BaseModel):
@@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):
class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
notification: NotificationType | None = None
application_status: ApplicationStatus
class SubscriptionStatusResponse(BaseModel):
subscribed: bool
class BillingInformation(BaseModel):
stripe_subscription_id: str
status: str
current_period_start: datetime
current_period_end: datetime
number_of_seats: int
cancel_at_period_end: bool
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool
@@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):
class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None
class ProductGatingResponse(BaseModel):
updated: bool
error: str | None
class SubscriptionSessionResponse(BaseModel):
sessionId: str

View File

@@ -0,0 +1,51 @@
from typing import cast
from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
else:
redis_client.srem(GATED_TENANTS_KEY, tenant_id)
def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
settings.application_status = application_status
store_settings(settings)
# Store gated tenant information in Redis
update_tenant_gating(tenant_id, application_status)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception:
logger.exception("Failed to gate product")
raise
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -85,7 +85,7 @@ if __name__ == "__main__":
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(_unused=True)
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
config, _ = get_test_config(

View File

@@ -17,7 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
_unused: bool = True
unused: bool = True
## Graph Output State

View File

@@ -9,7 +9,6 @@ class CoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
@@ -18,4 +17,4 @@ class SubgraphCoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add]
log_messages: Annotated[list[str], add] = []

View File

@@ -1,8 +1,8 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
@@ -12,14 +12,43 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)
@log_function_time(print_only=True)
def check_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
@@ -53,14 +82,40 @@ def check_sub_answer(
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
response = list(
fast_llm.stream(
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)
)
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
log_result = f"Answer quality: {quality_str}"
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Timeout Error - check sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Rate Limit Error - check sub answer")
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
@@ -69,7 +124,7 @@ def check_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=f"Answer quality: {quality_str}",
result=log_result,
)
],
)

View File

@@ -16,6 +16,23 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
dedup_sort_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -30,12 +47,23 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)
@log_function_time(print_only=True)
def generate_sub_answer(
state: AnswerQuestionState,
config: RunnableConfig,
@@ -51,12 +79,17 @@ def generate_sub_answer(
state.verified_reranked_documents
level, question_num = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
context_docs = dedup_sort_inference_section_list(context_docs)
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.search_request.persona
).contextualized_prompt
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
@@ -79,41 +112,67 @@ def generate_sub_answer(
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
agent_error: AgentErrorLog | None = None
try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate sub answer")
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
@@ -131,7 +190,7 @@ def generate_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result="",
result=log_results or "",
)
],
)

View File

@@ -42,10 +42,8 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
question: str
question_id: str
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.

View File

@@ -26,14 +26,31 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
@@ -42,12 +59,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_ci
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
@@ -56,8 +77,16 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.timing import log_function_time
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The initial answer could not be generated.",
rate_limit="LLM Rate Limit Error. The initial answer could not be generated.",
general_error="General LLM Error. The initial answer could not be generated.",
)
@log_function_time(print_only=True)
def generate_initial_answer(
state: SubQuestionRetrievalState,
config: RunnableConfig,
@@ -73,15 +102,19 @@ def generate_initial_answer(
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
sub_questions_cited_documents = state.cited_documents
# get all documents cited in sub-questions
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
state.sub_question_results
)
orig_question_retrieval_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
orig_question_retrieval_documents
):
if original_doc_number not in sub_questions_cited_documents:
if original_doc_number not in structured_subquestion_docs.cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
@@ -90,15 +123,18 @@ def generate_initial_answer(
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
sub_questions: list[str] = []
streamed_documents = (
relevant_docs
if len(relevant_docs) > 0
else state.orig_question_retrieved_documents[:15]
# Create the list of documents to stream out. Start with the
# ones that wil be in the context (or, if len == 0, use docs
# that were retrieved for the original question)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=relevant_docs,
context_documents=structured_subquestion_docs.context_documents,
original_question_docs=orig_question_retrieval_documents,
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER,
)
# Use the query info from the base document retrieval
@@ -108,11 +144,13 @@ def generate_initial_answer(
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(relevant_docs)
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=streamed_documents,
final_context_sections=streamed_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -128,7 +166,7 @@ def generate_initial_answer(
writer,
)
if len(relevant_docs) == 0:
if len(answer_generation_documents.context_documents) == 0:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
@@ -194,7 +232,7 @@ def generate_initial_answer(
model = graph_config.tooling.fast_llm
doc_context = format_docs(relevant_docs)
doc_context = format_docs(answer_generation_documents.context_documents)
doc_context = trim_prompt_piece(
config=model.config,
prompt_piece=doc_context,
@@ -224,30 +262,82 @@ def generate_initial_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
agent_error: AgentErrorLog | None = None
try:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate initial answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate initial answer")
if agent_error:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
return InitialAnswerUpdate(
initial_answer=None,
answer_error=AgentErrorLog(
error_message=agent_error.error_message or "An LLM error occurred",
error_type=agent_error.error_type,
error_result=agent_error.error_result,
),
initial_agent_stats=None,
generated_sub_questions=sub_questions,
agent_base_end_time=None,
agent_base_metrics=None,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"

View File

@@ -10,8 +10,10 @@ from onyx.agents.agent_search.deep_search.main.states import (
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def validate_initial_answer(
state: SubQuestionRetrievalState,
) -> InitialAnswerQualityUpdate:
@@ -25,7 +27,7 @@ def validate_initial_answer(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True
verdict = True # not actually required as already streamed out. Refinement will do similar
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,

View File

@@ -23,6 +23,8 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -33,17 +35,30 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT,
)
from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. Sub-questions could not be generated.",
rate_limit="LLM Rate Limit Error. Sub-questions could not be generated.",
general_error="General LLM Error. Sub-questions could not be generated.",
)
@log_function_time(print_only=True)
def decompose_orig_question(
state: SubQuestionRetrievalState,
config: RunnableConfig,
@@ -85,15 +100,15 @@ def decompose_orig_question(
]
)
decomposition_prompt = (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH_ASSUMING_REFINEMENT.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
else:
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
question=question, history=history
decomposition_prompt = (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT.format(
question=question, history=history
)
)
# Start decomposition
@@ -112,32 +127,42 @@ def decompose_orig_question(
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
streamed_tokens: list[BaseMessage_Content] = []
deomposition_response = merge_content(*streamed_tokens)
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
)
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
decomposition_response = merge_content(*streamed_tokens)
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
list_of_subqs = cast(str, decomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
initial_sub_questions = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
log_result = f"decomposed original question into {len(initial_sub_questions)} subquestions"
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
except LLMTimeoutError as e:
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:
logger.error("LLM Rate Limit Error - decompose orig question")
raise e
return InitialQuestionDecompositionUpdate(
initial_sub_questions=decomp_list,
initial_sub_questions=initial_sub_questions,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
@@ -151,7 +176,7 @@ def decompose_orig_question(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=f"decomposed original question into {len(decomp_list)} subquestions",
result=log_result,
)
],
)

View File

@@ -26,8 +26,8 @@ from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need impo
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
extract_entities_terms,
)
from onyx.agents.agent_search.deep_search.main.nodes.generate_refined_answer import (
generate_refined_answer,
from onyx.agents.agent_search.deep_search.main.nodes.generate_validate_refined_answer import (
generate_validate_refined_answer,
)
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
ingest_refined_sub_answers,
@@ -126,8 +126,8 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# Node to generate the refined answer
graph.add_node(
node="generate_refined_answer",
action=generate_refined_answer,
node="generate_validate_refined_answer",
action=generate_validate_refined_answer,
)
# Early node to extract the entities and terms from the initial answer,
@@ -215,11 +215,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_edge(
start_key="ingest_refined_sub_answers",
end_key="generate_refined_answer",
end_key="generate_validate_refined_answer",
)
graph.add_edge(
start_key="generate_refined_answer",
start_key="generate_validate_refined_answer",
end_key="compare_answers",
)
graph.add_edge(
@@ -252,9 +252,7 @@ if __name__ == "__main__":
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
inputs = MainInput(log_messages=[])
for thing in compiled_graph.stream(
input=inputs,

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
@@ -10,16 +11,51 @@ from onyx.agents.agent_search.deep_search.main.states import (
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out, and the answers could not be compared.",
rate_limit="The LLM encountered a rate limit, and the answers could not be compared.",
general_error="The LLM encountered an error, and the answers could not be compared.",
)
_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE = (
"Answer quality is not sufficient, so stay with the initial answer."
)
@log_function_time(print_only=True)
def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> InitialRefinedAnswerComparisonUpdate:
@@ -34,21 +70,75 @@ def compare_answers(
initial_answer = state.initial_answer
refined_answer = state.refined_answer
# if answer quality is not sufficient, then stay with the initial answer
if not state.refined_answer_quality:
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=False,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=_ANSWER_QUALITY_NOT_SUFFICIENT_MESSAGE,
)
],
)
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=compare_answers_prompt)]
agent_error: AgentErrorLog | None = None
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
resp: BaseMessage | None = None
refined_answer_improvement: bool | None = None
# no need to stream this
resp = model.invoke(msg)
try:
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - compare answers")
# continue as True in this support step
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - compare answers")
# continue as True in this support step
if agent_error or resp is None:
refined_answer_improvement = True
if agent_error:
log_result = agent_error.error_result
else:
log_result = "An answer could not be generated."
else:
refined_answer_improvement = binary_string_test(
text=cast(str, resp.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
)
log_result = f"Answer comparison: {refined_answer_improvement}"
write_custom_event(
"refined_answer_improvement",
@@ -65,7 +155,7 @@ def compare_answers(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=f"Answer comparison: {refined_answer_improvement}",
result=log_result,
)
],
)

View File

@@ -21,6 +21,18 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
@@ -30,12 +42,31 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS,
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_ANSWERED_SUBQUESTIONS_DIVIDER = "\n\n---\n\n"
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The sub-questions could not be generated.",
rate_limit="The LLM encountered a rate limit. The sub-questions could not be generated.",
general_error="The LLM encountered an error. The sub-questions could not be generated.",
)
@log_function_time(print_only=True)
def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedQuestionDecompositionUpdate:
@@ -72,8 +103,10 @@ def create_refined_sub_questions(
initial_question_answers = state.sub_question_results
addressed_question_list = [
x.question for x in initial_question_answers if x.verified_high_quality
addressed_subquestions_with_answers = [
f"Subquestion: {x.question}\nSubanswer:\n{x.answer}"
for x in initial_question_answers
if x.verified_high_quality and x.answer
]
failed_question_list = [
@@ -82,12 +115,14 @@ def create_refined_sub_questions(
msg = [
HumanMessage(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT_W_INITIAL_SUBQUESTION_ANSWERS.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
answered_subquestions_with_answers=_ANSWERED_SUBQUESTIONS_DIVIDER.join(
addressed_subquestions_with_answers
),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
@@ -96,29 +131,65 @@ def create_refined_sub_questions(
# Grader
model = graph_config.tooling.fast_llm
streamed_tokens = dispatch_separated(
model.stream(msg),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
response = merge_content(*streamed_tokens)
agent_error: AgentErrorLog | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - create refined sub questions")
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - create refined sub questions")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
if agent_error:
refined_sub_question_dict: dict[int, RefinementSubQuestion] = {}
log_result = agent_error.error_result
write_custom_event(
"refined_sub_question_creation_error",
StreamingError(
error="Your LLM was not able to create refined sub questions in time and timed out. Please try again.",
),
writer,
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
else:
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
log_result = f"Created {len(refined_sub_question_dict)} refined sub questions"
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
@@ -128,7 +199,7 @@ def create_refined_sub_questions(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
result=log_result,
)
],
)

View File

@@ -11,8 +11,10 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def decide_refinement_need(
state: MainState, config: RunnableConfig
) -> RequireRefinemenEvalUpdate:
@@ -26,6 +28,19 @@ def decide_refinement_need(
decision = True # TODO: just for current testing purposes
if state.answer_error:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result="Timeout Error",
)
],
)
log_messages = [
get_langgraph_node_log_string(
graph_component="main",

View File

@@ -21,11 +21,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def extract_entities_terms(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
@@ -81,6 +86,7 @@ def extract_entities_terms(
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -11,27 +11,49 @@ from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test_after_answer_separator,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.calculations import (
get_answer_generation_documents,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AGENT_ANSWER_SEPARATOR
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_deduplicated_structured_subquestion_documents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
@@ -43,26 +65,50 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_VALIDATION_PROMPT,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The refined answer could not be generated.",
rate_limit="The LLM encountered a rate limit. The refined answer could not be generated.",
general_error="The LLM encountered an error. The refined answer could not be generated.",
)
def generate_refined_answer(
@log_function_time(print_only=True)
def generate_validate_refined_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedAnswerUpdate:
"""
LangGraph node to generate the refined answer.
LangGraph node to generate the refined answer and validate it.
"""
node_start_time = datetime.now()
@@ -76,19 +122,24 @@ def generate_refined_answer(
)
verified_reranked_documents = state.verified_reranked_documents
sub_questions_cited_documents = state.cited_documents
# get all documents cited in sub-questions
structured_subquestion_docs = get_deduplicated_structured_subquestion_documents(
state.sub_question_results
)
original_question_verified_documents = (
state.orig_question_verified_reranked_documents
)
original_question_retrieved_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
consolidated_context_docs = structured_subquestion_docs.cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
original_question_verified_documents
):
if original_doc_number not in sub_questions_cited_documents:
if original_doc_number not in structured_subquestion_docs.cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
@@ -99,14 +150,16 @@ def generate_refined_answer(
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
relevant_docs = dedup_inference_section_list(consolidated_context_docs)
streaming_docs = (
relevant_docs
if len(relevant_docs) > 0
else original_question_retrieved_documents[:15]
# Create the list of documents to stream out. Start with the
# ones that wil be in the context (or, if len == 0, use docs
# that were retrieved for the original question)
answer_generation_documents = get_answer_generation_documents(
relevant_docs=relevant_docs,
context_documents=structured_subquestion_docs.context_documents,
original_question_docs=original_question_retrieved_documents,
max_docs=AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER,
)
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
@@ -114,11 +167,13 @@ def generate_refined_answer(
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs, or original question docs if no relevant docs are found
relevance_list = relevance_from_docs(relevant_docs)
relevance_list = relevance_from_docs(
answer_generation_documents.streaming_documents
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=streaming_docs,
final_context_sections=streaming_docs,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -199,7 +254,7 @@ def generate_refined_answer(
)
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(relevant_docs)
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
model.config,
relevant_docs_str,
@@ -231,28 +286,80 @@ def generate_refined_answer(
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
agent_error: AgentErrorLog | None = None
start_stream_token = datetime.now()
try:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - generate refined answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - generate refined answer")
if agent_error:
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
"initial_agent_answer",
StreamingError(
error=AGENT_LLM_TIMEOUT_MESSAGE,
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
return RefinedAnswerUpdate(
refined_answer=None,
refined_answer_quality=False, # TODO: replace this with the actual check value
refined_agent_stats=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=0.0,
refined_question_boost_factor=0.0,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
result=agent_error.error_result or "An LLM error occurred",
)
],
)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
@@ -261,54 +368,43 @@ def generate_refined_answer(
response = merge_content(*streamed_tokens)
answer = cast(str, response)
# run a validation step for the refined answer only
msg = [
HumanMessage(
content=REFINED_ANSWER_VALIDATION_PROMPT.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=sub_question_answer_str,
relevant_docs=relevant_docs_str,
proposed_answer=answer,
persona_specification=persona_contextualized_prompt,
)
)
]
try:
validation_response = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
separator=AGENT_ANSWER_SEPARATOR,
)
except LLMTimeoutError:
refined_answer_quality = True
logger.error("LLM Timeout Error - validate refined answer")
except LLMRateLimitError:
refined_answer_quality = True
logger.error("LLM Rate Limit Error - validate refined answer")
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=refined_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
@@ -325,7 +421,7 @@ def generate_refined_answer(
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=True, # TODO: replace this with the actual check value
refined_answer_quality=refined_answer_quality,
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,

View File

@@ -17,6 +17,7 @@ from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
@@ -76,6 +77,7 @@ class InitialAnswerUpdate(LoggerUpdate):
"""
initial_answer: str | None = None
answer_error: AgentErrorLog | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
@@ -88,6 +90,7 @@ class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
refined_answer: str | None = None
answer_error: AgentErrorLog | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False

View File

@@ -16,16 +16,44 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="Query rewriting failed due to LLM timeout - the original question will be used.",
rate_limit="Query rewriting failed due to LLM rate limit - the original question will be used.",
general_error="Query rewriting failed due to LLM error - the original question will be used.",
)
@log_function_time(print_only=True)
def expand_queries(
state: ExpandedRetrievalInput,
config: RunnableConfig,
@@ -54,13 +82,43 @@ def expand_queries(
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
)
agent_error: AgentErrorLog | None = None
llm_response_list: list[BaseMessage_Content] = []
llm_response = ""
rewritten_queries = []
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
try:
llm_response_list = dispatch_separated(
llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[
0
].content
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
rewritten_queries = llm_response.split("\n")
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
logger.error("LLM Timeout Error - expand queries")
log_result = agent_error.error_result
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
logger.error("LLM Rate Limit Error - expand queries")
log_result = agent_error.error_result
# use subquestion as query if query generation fails
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
@@ -69,7 +127,7 @@ def expand_queries(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=f"Number of expanded queries: {len(rewritten_queries)}",
result=log_result,
)
],
)

View File

@@ -21,12 +21,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def rerank_documents(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
@@ -39,6 +42,8 @@ def rerank_documents(
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
# If no question defined/question is None in the state, use the original
# question from the search request as query
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
@@ -47,39 +52,28 @@ def rerank_documents(
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)
# skip section filtering
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if rerank_settings is None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")

View File

@@ -23,12 +23,15 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.utils.timing import log_function_time
@log_function_time(print_only=True)
def retrieve_documents(
state: RetrievalInput, config: RunnableConfig
) -> DocRetrievalUpdate:
@@ -67,9 +70,12 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -1,5 +1,7 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
@@ -10,14 +12,38 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
_llm_node_error_strings = LLMNodeErrorStrings(
timeout="The LLM timed out. The document could not be verified. The document will be treated as 'relevant'",
rate_limit="The LLM encountered a rate limit. The document could not be verified. The document will be treated as 'relevant'",
general_error="The LLM encountered an error. The document could not be verified. The document will be treated as 'relevant'",
)
@log_function_time(print_only=True)
def verify_documents(
state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
@@ -26,12 +52,14 @@ def verify_documents(
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing ProSearchConfig
config (RunnableConfig): Configuration containing AgentSearchConfig
Updates:
verified_documents: list[InferenceSection]
"""
node_start_time = datetime.now()
question = state.question
retrieved_document_to_verify = state.retrieved_document_to_verify
document_content = retrieved_document_to_verify.combined_content
@@ -51,12 +79,40 @@ def verify_documents(
)
]
response = fast_llm.invoke(msg)
response: BaseMessage | None = None
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(retrieved_document_to_verify)
verified_documents = [
retrieved_document_to_verify
] # default is to treat document as relevant
try:
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)
assert isinstance(response.content, str)
if not binary_string_test(
text=response.content, positive_value=AGENT_POSITIVE_VALUE_STR
):
verified_documents = []
except LLMTimeoutError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
logger.error("LLM Timeout Error - verify documents")
except LLMRateLimitError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
logger.error("LLM Rate Limit Error - verify documents")
return DocVerificationUpdate(
verified_documents=verified_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="verify documents",
node_start_time=node_start_time,
)
],
)

View File

@@ -21,9 +21,13 @@ from onyx.context.search.models import InferenceSection
class ExpandedRetrievalInput(SubgraphCoreState):
question: str = ""
base_search: bool = False
# exception from 'no default value'for LangGraph input states
# Here, sub_question_id default None implies usage for the
# original question. This is sometimes needed for nested sub-graphs
sub_question_id: str | None = None
question: str
base_search: bool
## Update/Return States
@@ -34,7 +38,7 @@ class QueryExpansionUpdate(LoggerUpdate, BaseModel):
log_messages: list[str] = []
class DocVerificationUpdate(BaseModel):
class DocVerificationUpdate(LoggerUpdate, BaseModel):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
@@ -88,4 +92,4 @@ class DocVerificationInput(ExpandedRetrievalInput):
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str = ""
query_to_retrieve: str

View File

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput_a,
MainInput as MainInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
@@ -21,6 +21,7 @@ from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
@@ -33,6 +34,7 @@ from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
@@ -72,13 +74,15 @@ def _parse_agent_event(
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
elif event["name"] == "refined_sub_question_creation_error":
return cast(StreamingError, event["data"])
return None
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput_a,
graph_input: BasicInput | MainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -92,7 +96,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput_a,
input: BasicInput | MainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -123,9 +127,7 @@ def run_main_graph(
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
input = MainInput(log_messages=[])
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@@ -140,7 +142,7 @@ def run_basic_graph(
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput()
input = BasicInput(unused=True)
return run_graph(compiled_graph, config, input)
@@ -172,9 +174,7 @@ if __name__ == "__main__":
# search_request.persona = get_persona_by_id(1, None, db_session)
# config.perform_initial_search_path_decision = False
config.behavior.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
input = MainInput(log_messages=[])
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):

View File

@@ -7,6 +7,7 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
AgentPromptEnrichmentComponents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
@@ -40,13 +41,7 @@ def build_sub_question_answer_prompt(
date_str = build_date_time_string()
# TODO: This should include document metadata and title
docs_format_list = [
f"Document Number: [D{doc_num + 1}]\nContent: {doc.combined_content}\n\n"
for doc_num, doc in enumerate(docs)
]
docs_str = "\n\n".join(docs_format_list)
docs_str = format_docs(docs)
docs_str = trim_prompt_piece(
config,
@@ -150,3 +145,38 @@ def get_prompt_enrichment_components(
history=history,
date_str=date_str,
)
def binary_string_test(text: str, positive_value: str = "yes") -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
return positive_value.lower() in text.lower()
def binary_string_test_after_answer_separator(
text: str, positive_value: str = "yes", separator: str = "Answer:"
) -> bool:
"""
Tests if a string contains a positive value (case-insensitive).
Args:
text: The string to test
positive_value: The value to look for (defaults to "yes")
Returns:
True if the positive value is found in the text
"""
if separator not in text:
return False
relevant_text = text.split(f"{separator}")[-1]
return binary_string_test(relevant_text, positive_value)

View File

@@ -1,7 +1,11 @@
import numpy as np
from onyx.agents.agent_search.shared_graph_utils.models import AnswerGenerationDocuments
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
@@ -96,3 +100,106 @@ def get_fit_scores(
)
return fit_eval
def get_answer_generation_documents(
relevant_docs: list[InferenceSection],
context_documents: list[InferenceSection],
original_question_docs: list[InferenceSection],
max_docs: int,
) -> AnswerGenerationDocuments:
"""
Create a deduplicated list of documents to stream, prioritizing relevant docs.
Args:
relevant_docs: Primary documents to include
context_documents: Additional context documents to append
original_question_docs: Original question documents to append
max_docs: Maximum number of documents to return
Returns:
List of deduplicated documents, limited to max_docs
"""
# get relevant_doc ids
relevant_doc_ids = [doc.center_chunk.document_id for doc in relevant_docs]
# Start with relevant docs or fallback to original question docs
streaming_documents = relevant_docs.copy()
# Use a set for O(1) lookups of document IDs
seen_doc_ids = {doc.center_chunk.document_id for doc in streaming_documents}
# Combine additional documents to check in one iteration
additional_docs = context_documents + original_question_docs
for doc_idx, doc in enumerate(additional_docs):
doc_id = doc.center_chunk.document_id
if doc_id not in seen_doc_ids:
streaming_documents.append(doc)
seen_doc_ids.add(doc_id)
streaming_documents = dedup_inference_section_list(streaming_documents)
relevant_streaming_docs = [
doc
for doc in streaming_documents
if doc.center_chunk.document_id in relevant_doc_ids
]
relevant_streaming_docs = dedup_sort_inference_section_list(relevant_streaming_docs)
additional_streaming_docs = [
doc
for doc in streaming_documents
if doc.center_chunk.document_id not in relevant_doc_ids
]
additional_streaming_docs = dedup_sort_inference_section_list(
additional_streaming_docs
)
for doc in additional_streaming_docs:
if doc.center_chunk.score:
doc.center_chunk.score += -2.0
else:
doc.center_chunk.score = -2.0
sorted_streaming_documents = relevant_streaming_docs + additional_streaming_docs
return AnswerGenerationDocuments(
streaming_documents=sorted_streaming_documents[:max_docs],
context_documents=relevant_streaming_docs[:max_docs],
)
def dedup_sort_inference_section_list(
sections: list[InferenceSection],
) -> list[InferenceSection]:
"""Deduplicates InferenceSections by document_id and sorts by score.
Args:
sections: List of InferenceSections to deduplicate and sort
Returns:
Deduplicated list of InferenceSections sorted by score in descending order
"""
# dedupe/merge with existing framework
sections = dedup_inference_section_list(sections)
# Use dict to deduplicate by document_id, keeping highest scored version
unique_sections: dict[str, InferenceSection] = {}
for section in sections:
doc_id = section.center_chunk.document_id
if doc_id not in unique_sections:
unique_sections[doc_id] = section
continue
# Keep version with higher score
existing_score = unique_sections[doc_id].center_chunk.score or 0
new_score = section.center_chunk.score or 0
if new_score > existing_score:
unique_sections[doc_id] = section
# Sort by score in descending order, handling None scores
sorted_sections = sorted(
unique_sections.values(), key=lambda x: x.center_chunk.score or 0, reverse=True
)
return sorted_sections

View File

@@ -0,0 +1,19 @@
from enum import Enum
AGENT_LLM_TIMEOUT_MESSAGE = "The agent timed out. Please try again."
AGENT_LLM_ERROR_MESSAGE = "The agent encountered an error. Please try again."
AGENT_LLM_RATELIMIT_MESSAGE = (
"The agent encountered a rate limit error. Please try again."
)
LLM_ANSWER_ERROR_MESSAGE = "The question was not answered due to an LLM error."
AGENT_POSITIVE_VALUE_STR = "yes"
AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"
GENERAL_ERROR = "general_error"

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.models import (
@@ -56,6 +58,12 @@ class InitialAgentResultStats(BaseModel):
agent_effectiveness: dict[str, float | int | None]
class AgentErrorLog(BaseModel):
error_message: str
error_type: str
error_result: str
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
@@ -110,6 +118,11 @@ class SubQuestionAnswerResults(BaseModel):
sub_question_retrieval_stats: AgentChunkRetrievalStats
class StructuredSubquestionDocuments(BaseModel):
cited_documents: list[InferenceSection]
context_documents: list[InferenceSection]
class CombinedAgentMetrics(BaseModel):
timings: AgentTimings
base_metrics: AgentBaseMetrics | None
@@ -126,3 +139,17 @@ class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str
class LLMNodeErrorStrings(BaseModel):
timeout: str = "LLM Timeout Error"
rate_limit: str = "LLM Rate Limit Error"
general_error: str = "General LLM Error"
class AnswerGenerationDocuments(BaseModel):
streaming_documents: list[InferenceSection]
context_documents: list[InferenceSection]
BaseMessage_Content = str | list[str | dict[str, Any]]

View File

@@ -12,6 +12,13 @@ def dedup_inference_sections(
return deduped
def dedup_inference_section_list(
list: list[InferenceSection],
) -> list[InferenceSection]:
deduped = _merge_sections(list)
return deduped
def dedup_question_answer_results(
question_answer_results_1: list[SubQuestionAnswerResults],
question_answer_results_2: list[SubQuestionAnswerResults],

View File

@@ -20,10 +20,18 @@ from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import BaseMessage_Content
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions
from onyx.agents.agent_search.shared_graph_utils.models import (
StructuredSubquestionDocuments,
)
from onyx.agents.agent_search.shared_graph_utils.models import SubQuestionAnswerResults
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
@@ -34,6 +42,9 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -46,6 +57,8 @@ from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
@@ -58,6 +71,7 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
@@ -65,8 +79,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
BaseMessage_Content = str | list[str | dict[str, Any]]
logger = setup_logger()
# Post-processing
@@ -218,7 +233,10 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
@@ -341,8 +359,12 @@ def retrieve_search_docs(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
@@ -372,8 +394,24 @@ def summarize_history(
)
)
history_response = llm.invoke(history_context_prompt)
try:
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this
)
assert isinstance(history_response.content, str)
return history_response.content
@@ -439,3 +477,27 @@ def remove_document_citations(text: str) -> str:
# \d+ - one or more digits
# \] - literal ] character
return re.sub(r"\[(?:D|Q)?\d+\]", "", text)
def get_deduplicated_structured_subquestion_documents(
sub_question_results: list[SubQuestionAnswerResults],
) -> StructuredSubquestionDocuments:
"""
Extract and deduplicate all cited documents from sub-question results.
Args:
sub_question_results: List of sub-question results containing cited documents
Returns:
Deduplicated list of cited documents
"""
cited_docs = [
doc for result in sub_question_results for doc in result.cited_documents
]
context_docs = [
doc for result in sub_question_results for doc in result.context_documents
]
return StructuredSubquestionDocuments(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)

View File

@@ -1,7 +1,7 @@
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -13,23 +13,150 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width" />
<title>{title}</title>
<style>
body, table, td, a {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
text-size-adjust: 100%;
margin: 0;
padding: 0;
-webkit-font-smoothing: antialiased;
-webkit-text-size-adjust: none;
}}
body {{
background-color: #f7f7f7;
color: #333;
}}
.body-content {{
color: #333;
}}
.email-container {{
width: 100%;
max-width: 600px;
margin: 0 auto;
background-color: #ffffff;
border-radius: 6px;
overflow: hidden;
border: 1px solid #eaeaea;
}}
.header {{
background-color: #000000;
padding: 20px;
text-align: center;
}}
.header img {{
max-width: 140px;
}}
.body-content {{
padding: 20px 30px;
}}
.title {{
font-size: 20px;
font-weight: bold;
margin: 0 0 10px;
}}
.message {{
font-size: 16px;
line-height: 1.5;
margin: 0 0 20px;
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
}}
.footer {{
font-size: 13px;
color: #6A7280;
text-align: center;
padding: 20px;
}}
.footer a {{
color: #6b7280;
text-decoration: underline;
}}
</style>
</head>
<body>
<table role="presentation" class="email-container" cellpadding="0" cellspacing="0">
<tr>
<td class="header">
<img
style="background-color: #ffffff; border-radius: 8px;"
src="https://www.onyx.app/logos/customer/onyx.png"
alt="Onyx Logo"
>
</td>
</tr>
<tr>
<td class="body-content">
<h1 class="title">{heading}</h1>
<div class="message">
{message}
</div>
{cta_block}
</td>
</tr>
<tr>
<td class="footer">
© {year} Onyx. All rights reserved.
<br>
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
</td>
</tr>
</table>
</body>
</html>
"""
def build_html_email(
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
) -> str:
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
else:
cta_block = ""
return HTML_EMAIL_TEMPLATE.format(
title=heading,
heading=heading,
message=message,
cta_block=cta_block,
year=datetime.now().year,
)
def send_email(
user_email: str,
subject: str,
body: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
msg = MIMEMultipart()
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg.attach(MIMEText(body))
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
msg.attach(part_text)
msg.attach(part_html)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -40,26 +167,44 @@ def send_email(
raise e
def send_subscription_cancellation_email(user_email: str) -> None:
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>Were sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
body = dedent(
f"""\
Hello,
You have been invited to join an organization on Onyx.
To join the organization, please visit the following link:
{WEB_DOMAIN}/auth/signup?email={user_email}
You'll be asked to set a password or login with Google to complete your registration.
Best regards,
The Onyx Team
"""
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
send_email(user_email, subject, body, current_user.email)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
send_email(user_email, subject, html_content, text_content)
def send_forgot_password_email(
@@ -68,13 +213,15 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
# Keep search param same name as cookie for simplicity
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
def send_user_verification_email(
@@ -82,7 +229,12 @@ def send_user_verification_email(
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)

View File

@@ -144,7 +144,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(OnyxRedisConstants.ACTIVE_FENCES)

View File

@@ -19,6 +19,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
# tasks that run in either self-hosted on cloud
@@ -35,6 +36,15 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-checkpoint-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
@@ -56,16 +66,7 @@ beat_task_templates.extend(
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -141,14 +142,14 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
return cloud_task
# tasks that only run in the cloud
# tasks that only run in the cloud and are system wide
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
# by the DynamicTenantScheduler as system wide task and not a per tenant task
beat_system_tasks: list[dict] = [
beat_cloud_tasks: list[dict] = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-alembic",
"task": OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
"schedule": timedelta(hours=1),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
@@ -156,11 +157,37 @@ beat_system_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-queues",
"task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES,
"schedule": timedelta(seconds=30),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule = beat_task_templates
tasks_to_schedule.extend(
[
{
"name": "monitor-celery-queues",
"task": OnyxCeleryTask.MONITOR_CELERY_QUEUES,
"schedule": timedelta(seconds=10),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)
tasks_to_schedule.extend(beat_task_templates)
def generate_cloud_tasks(
@@ -180,23 +207,24 @@ def generate_cloud_tasks(
if beat_multiplier <= 0:
raise ValueError("beat_multiplier must be positive!")
# start with the incoming beat tasks
cloud_tasks: list[dict] = copy.deepcopy(beat_tasks)
cloud_tasks: list[dict] = []
# generate our cloud tasks from the templates
# generate our tenant aware cloud tasks from the templates
for beat_template in beat_templates:
cloud_task = make_cloud_generator_task(beat_template)
cloud_tasks.append(cloud_task)
# factor in the cloud multiplier
# factor in the cloud multiplier for the above
for cloud_task in cloud_tasks:
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
# add the fixed cloud/system beat tasks. No multiplier for these.
cloud_tasks.extend(copy.deepcopy(beat_tasks))
return cloud_tasks
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
return generate_cloud_tasks(beat_system_tasks, beat_task_templates, beat_multiplier)
return generate_cloud_tasks(beat_cloud_tasks, beat_task_templates, beat_multiplier)
def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@@ -1,10 +1,14 @@
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -12,18 +16,35 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
class TaskDependencyError(RuntimeError):
@@ -42,6 +63,7 @@ def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -77,6 +99,18 @@ def check_for_connector_deletion_task(
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -212,3 +246,158 @@ def try_generate_document_cc_pair_cleanup_tasks(
redis_connector.delete.set_fence(fence_payload)
return tasks_generated
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
)
return
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"onyx.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.reset()

View File

@@ -175,6 +175,24 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -349,6 +367,7 @@ def connector_permission_sync_generator_task(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
@@ -756,7 +775,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
raise
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
"""Monitoring CCPair permissions utils"""
def monitor_ccpair_permissions_taskset(

View File

@@ -26,11 +26,11 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.error_logging import emit_background_error
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -72,18 +72,26 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
task_logger.error(
f"Recieved non-sync CC Pair {cc_pair.id} for external "
f"group sync. Actual access type: {cc_pair.access_type}"
)
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"CC Pair is being deleted"
)
return False
# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"no group sync function for {cc_pair.connector.source}"
)
return False
# If the last sync is None, it has never been run so we run the sync
@@ -125,6 +133,9 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.warning(
f"Failed to acquire beat lock for external group sync: {tenant_id}"
)
return None
try:
@@ -205,20 +216,12 @@ def try_creating_external_group_sync_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
logger.warning(
f"Skipping external group sync for CC Pair {cc_pair_id} - already running."
)
return None
redis_connector.external_group_sync.generator_clear()
@@ -269,9 +272,6 @@ def try_creating_external_group_sync_task(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()
return payload_id
@@ -304,22 +304,26 @@ def connector_external_group_sync_generator_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
msg = (
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
if not redis_connector.external_group_sync.fenced: # The fence must exist
raise ValueError(
msg = (
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
payload = redis_connector.external_group_sync.payload # The payload must exist
if not payload:
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
msg = "connector_external_group_sync_generator_task: payload invalid or not found"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
if payload.celery_task_id is None:
logger.info(
@@ -344,9 +348,9 @@ def connector_external_group_sync_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
msg = f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
task_logger.error(msg)
return None
try:
@@ -367,9 +371,9 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
@@ -400,9 +404,9 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
@@ -492,9 +496,11 @@ def validate_external_group_sync_fence(
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
msg = (
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
emit_background_error(msg)
task_logger.error(msg)
return
cc_pair_id = int(cc_pair_id_str)
@@ -509,12 +515,14 @@ def validate_external_group_sync_fence(
try:
payload = redis_connector.external_group_sync.payload
except ValidationError:
task_logger.exception(
msg = (
"validate_external_group_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
task_logger.exception(msg)
emit_background_error(msg, cc_pair_id=cc_pair_id)
redis_connector.external_group_sync.reset()
return
@@ -551,12 +559,15 @@ def validate_external_group_sync_fence(
# return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
emit_background_error(
message=(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
),
cc_pair_id=cc_pair_id,
)
redis_connector.external_group_sync.reset()

File diff suppressed because it is too large Load Diff

View File

@@ -240,7 +240,8 @@ def validate_indexing_fence(
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return

View File

@@ -17,7 +17,8 @@ from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
@@ -189,9 +190,9 @@ def _build_connector_start_latency_metric(
desired_start_time = cc_pair.connector.time_created
else:
if not cc_pair.connector.refresh_freq:
task_logger.error(
"Found non-initial index attempt for connector "
"without refresh_freq. This should never happen."
task_logger.debug(
"Connector has no refresh_freq and this is a non-initial index attempt. "
"Assuming user manually triggered indexing, so we'll skip start latency metric."
)
return None
@@ -722,7 +723,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
name=OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
@@ -852,3 +853,55 @@ def cloud_check_alembic() -> bool | None:
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True
@shared_task(
name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, ignore_result=True, bind=True
)
def cloud_monitor_celery_queues(
self: Task,
) -> None:
return monitor_celery_queues_helper(self)
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
return monitor_celery_queues_helper(self)
def monitor_celery_queues_helper(
task: Task,
) -> None:
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)

View File

@@ -122,34 +122,39 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
return None
try:
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# the entire task needs to run frequently in order to finalize pruning
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
if not _is_pruning_due(cc_pair):
continue
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue
payload_id = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
continue
if not _is_pruning_due(cc_pair):
continue
task_logger.info(
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
)
payload_id = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
continue
task_logger.info(
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
)
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=3600)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
@@ -163,6 +168,22 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
task_logger.exception("Exception while validating pruning fences")
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -481,7 +502,7 @@ def connector_pruning_generator_task(
)
"""Monitoring pruning utils, called in monitor_vespa_sync"""
"""Monitoring pruning utils"""
def monitor_ccpair_pruning_taskset(

View File

@@ -8,6 +8,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
@@ -104,6 +105,7 @@ def document_by_cc_pair_cleanup_task(
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
@@ -252,7 +254,11 @@ def cloud_beat_task_generator(
try:
tenant_ids = get_all_tenant_ids()
gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
@@ -270,6 +276,7 @@ def cloud_beat_task_generator(
queue=queue,
priority=priority,
expires=expires,
ignore_result=True,
)
except SoftTimeLimitExceeded:
task_logger.info(

View File

@@ -1,9 +1,5 @@
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import Any
from typing import cast
@@ -13,8 +9,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -22,47 +16,27 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
monitor_ccpair_permissions_taskset,
)
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import count_documents_by_needs_sync
from onyx.db.document import get_document
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import delete_document_set
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings
@@ -72,20 +46,14 @@ from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -94,7 +62,6 @@ from onyx.utils.variable_functionality import (
)
from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -111,9 +78,14 @@ logger = setup_logger()
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
task_logger.info("check_for_vespa_sync_task started")
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -125,6 +97,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
return None
try:
# 1/3: KICKOFF
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
@@ -151,9 +124,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
# endregion
# check if any user groups are not synced
lock_beat.reacquire()
if global_version.is_ee_version():
lock_beat.reacquire()
try:
fetch_user_groups = fetch_versioned_implementation(
"onyx.db.user_group", "fetch_user_groups"
@@ -179,6 +151,35 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
# 2/3: VALIDATE: TODO
# 3/3: FINALIZE
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -495,484 +496,23 @@ def monitor_document_set_taskset(
task_logger.info(
f"Successfully synced document set: document_set={document_set_id}"
)
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
rds.reset()
def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
)
return
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
remaining = redis_connector.delete.get_remaining()
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"onyx.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
num_docs_synced=initial_count,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
except Exception:
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.reset()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"Connector indexing: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
# the task is still setting up
return
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# Verify: if the generator isn't complete, the task must not be in READY state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
task_state = result.state
if (
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
)
task_logger.warning(msg)
try:
index_attempt = get_index_attempt(
db_session, payload.index_attempt_id
)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
except Exception:
task_logger.exception(
"Connector indexing - Transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector_index.reset()
return
if redis_connector_index.watchdog_signaled():
# if the generator is complete, don't clean up until the watchdog has exited
task_logger.info(
f"Connector indexing - Delaying finalization until watchdog has exited: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
return
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
redis_connector_index.reset()
@shared_task(
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
ignore_result=True,
soft_time_limit=300,
bind=True,
)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes various long running tasks.
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
now. Should change that at some point.
It scans for fence values and then gets the counts of any associated tasksets.
For many tasks, the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
# Replica usage notes
#
# False negatives are OK. (aka fail to to see a key that exists on the master).
# We simply skip the monitoring work and it will be caught on the next pass.
#
# False positives are not OK, and are possible if we clear a fence on the master and
# then read from the replica. In this case, monitoring work could be done on a fence
# that no longer exists. To avoid this, we scan from the replica, but double check
# the result on the master.
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return None
try:
# print current queue lengths
time.monotonic()
# we don't need every tenant polling redis for this info.
if not MULTI_TENANT or random.randint(1, 10) == 10:
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
n_deletion = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
n_pruning = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
"update_sync_record_status exceptioned. "
f"document_set_id={document_set_id} "
"Resetting document set regardless."
)
prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)
# we want to run this less frequently than the overall task
if not r.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
# build a lookup table of existing fences
# this is just a migration concern and should be unnecessary once
# lookup tables are rolled out
for key_bytes in r_replica.scan_iter(count=SCAN_ITER_COUNT_DEFAULT):
if is_fence(key_bytes) and not r.sismember(
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
):
logger.warning(f"Adding {key_bytes} to the lookup table.")
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
r.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
elif key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
else:
pass
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
return False
except Exception:
task_logger.exception("monitor_vespa_sync exceptioned.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"monitor_vespa_sync - Lock not owned on completion: "
f"tenant={tenant_id}"
# f"timings={timings}"
)
redis_lock_dump(lock_beat, r)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True
rds.reset()
@shared_task(
@@ -1072,23 +612,3 @@ def vespa_metadata_sync_task(
self.retry(exc=e, countdown=countdown)
return True
def is_fence(key_bytes: bytes) -> bool:
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
return True
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
return True
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True
return False

View File

@@ -0,0 +1,13 @@
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_tenant
def emit_background_error(
message: str,
cc_pair_id: int | None = None,
) -> None:
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)

View File

@@ -1,80 +0,0 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime
from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
def _2010_dt() -> datetime.datetime:
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
def _2020_dt() -> datetime.datetime:
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
def _default_end_time(
last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
"""If year is before 2010, go to the beginning of 2010.
If year is 2010-2020, go in 5 year increments.
If year > 2020, then go in 180 day increments.
For connectors that don't support a `filter_by` and instead rely on `sort_by`
for polling, then this will cause a massive duplication of fetches. For these
connectors, you may want to override this function to return a more reasonable
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
last_successful_run = (
datetime_to_utc(last_successful_run) if last_successful_run else None
)
if last_successful_run is None or last_successful_run < _2010_dt():
return _2010_dt()
if last_successful_run < _2020_dt():
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
return last_successful_run + datetime.timedelta(days=180)
def find_end_time_for_indexing_attempt(
last_successful_run: datetime.datetime | None,
# source_type can be used to override the default for certain connectors, currently unused
source_type: DocumentSource,
) -> datetime.datetime | None:
"""Is the current time unless the connector is run over a large period, in which case it is
split up into large time segments that become smaller as it approaches the present
"""
# NOTE: source_type can be used to override the default for certain connectors
end_of_window = _default_end_time(last_successful_run)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if end_of_window < now:
return end_of_window
# None signals that we should index up to current time
return None
def get_time_windows_for_index_attempt(
last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
start_of_window: datetime.datetime | None = last_successful_run
while start_of_window:
end_of_window = find_end_time_for_indexing_attempt(
last_successful_run=start_of_window, source_type=source_type
)
time_windows.append(
(
start_of_window,
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
)
)
start_of_window = end_of_window
return time_windows

View File

@@ -0,0 +1,200 @@
from datetime import datetime
from datetime import timedelta
from io import BytesIO
from sqlalchemy import and_
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof
logger = setup_logger()
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100
def _build_checkpoint_pointer(index_attempt_id: int) -> str:
return f"checkpoint_{index_attempt_id}.json"
def save_checkpoint(
db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
) -> str:
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=checkpoint_pointer,
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
file_origin=FileOrigin.INDEXING_CHECKPOINT,
file_type="application/json",
)
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
index_attempt.checkpoint_pointer = checkpoint_pointer
db_session.add(index_attempt)
db_session.commit()
return checkpoint_pointer
def load_checkpoint(
db_session: Session, index_attempt_id: int
) -> ConnectorCheckpoint | None:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:
return None
def get_latest_valid_checkpoint(
db_session: Session,
cc_pair_id: int,
search_settings_id: int,
window_start: datetime,
window_end: datetime,
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
db_session=db_session,
limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
)
checkpoint_candidates = [
candidate
for candidate in checkpoint_candidates
if (
candidate.poll_range_start == window_start
and candidate.poll_range_end == window_end
and candidate.status == IndexingStatus.FAILED
and candidate.checkpoint_pointer is not None
# we want to make sure that the checkpoint is actually useful
# if it's only gone through a few docs, it's probably not worth
# using. This also avoids weird cases where a connector is basically
# non-functional but still "makes progress" by slowly moving the
# checkpoint forward run after run
and candidate.total_docs_indexed
and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
)
]
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# for now, capped at 10
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
logger.warning(
f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
latest_valid_checkpoint_candidate = (
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
)
except Exception:
logger.exception(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}."
)
previous_checkpoint = None
if previous_checkpoint is not None:
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
save_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
checkpoint=previous_checkpoint,
)
checkpoint = previous_checkpoint
return checkpoint
def get_index_attempts_with_old_checkpoints(
db_session: Session, days_to_keep: int = 7
) -> list[IndexAttempt]:
"""Get all index attempts with checkpoints older than the specified number of days.
Args:
db_session: The database session
days_to_keep: Number of days to keep checkpoints for (default: 7)
Returns:
Number of checkpoints deleted
"""
cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)
# Find all index attempts with checkpoints older than cutoff_date
old_attempts = (
db_session.query(IndexAttempt)
.filter(
and_(
IndexAttempt.checkpoint_pointer.isnot(None),
IndexAttempt.time_created < cutoff_date,
)
)
.all()
)
return old_attempts
def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
"""Clean up a checkpoint for a given index attempt"""
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None
db_session.add(index_attempt)
db_session.commit()
return None
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
)

View File

@@ -5,6 +5,8 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
import sys
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
@@ -18,6 +20,16 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class SimpleJobException(Exception):
"""lets us raise an exception that will return a specific error code"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
code: int | None = kwargs.pop("code", None)
self.code = code
super().__init__(*args, **kwargs)
JobStatusType = (
Literal["error"]
| Literal["finished"]
@@ -28,7 +40,10 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
@@ -52,13 +67,29 @@ def _initializer(
)
# Proceed with executing the target function
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except SimpleJobException as e:
logger.exception("SimpleJob raised a SimpleJobException")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(e.code) # use the given exit code
except Exception:
logger.exception("SimpleJob raised an exception")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
) -> None:
_initializer(func, args, kwargs)
_initializer(func, queue, args, kwargs)
@dataclass
@@ -67,6 +98,8 @@ class SimpleJob:
id: int
process: Optional["SpawnProcess"] = None
queue: Optional[mp.Queue] = None
_exception: Optional[str] = None
def cancel(self) -> bool:
return self.release()
@@ -100,9 +133,15 @@ class SimpleJob:
def exception(self) -> str:
"""Needed to match the Dask API, but not implemented since we don't currently
have a way to get back the exception information from the child process."""
return (
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
)
"""Retrieve exception from the multiprocessing queue if available."""
if self._exception is None and self.queue and not self.queue.empty():
self._exception = self.queue.get() # Get exception from queue
if self._exception:
return self._exception
return f"Job with ID '{self.id}' did not report an exception."
class SimpleJobClient:
@@ -137,8 +176,11 @@ class SimpleJobClient:
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
queue = ctx.Queue()
process = ctx.Process(
target=_run_in_process, args=(func, queue, args), daemon=True
)
job = SimpleJob(id=job_id, process=process, queue=queue)
process.start()
self.jobs[job_id] = job

View File

@@ -0,0 +1,87 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class MemoryTracer:
def __init__(self, interval: int = 0, num_print_entries: int = 5):
self.interval = interval
self.num_print_entries = num_print_entries
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
self.counter = 0
def start(self) -> None:
"""Start the memory tracer if interval is greater than 0."""
if self.interval > 0:
logger.debug(f"Memory tracer starting: interval={self.interval}")
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
self._take_snapshot()
def stop(self) -> None:
"""Stop the memory tracer if it's running."""
if self.interval > 0:
self.log_final_diff()
tracemalloc.stop()
logger.debug("Memory tracer stopped.")
def _take_snapshot(self) -> None:
"""Take a snapshot and update internal snapshot states."""
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__),
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def _log_diff(
self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
) -> None:
"""Log the memory difference between two snapshots."""
stats = current.compare_to(previous, "traceback")
for s in stats[: self.num_print_entries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def increment_and_maybe_trace(self) -> None:
"""Increment counter and perform trace if interval is hit."""
if self.interval <= 0:
return
self.counter += 1
if self.counter % self.interval == 0:
logger.debug(
f"Running trace comparison for batch {self.counter}. interval={self.interval}"
)
self._take_snapshot()
if self.snapshot and self.snapshot_prev:
self._log_diff(self.snapshot, self.snapshot_prev)
def log_final_diff(self) -> None:
"""Log the final memory diff between start and end of indexing."""
if self.interval <= 0:
return
logger.debug(
f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
)
self._take_snapshot()
if self.snapshot and self.snapshot_first:
self._log_diff(self.snapshot, self.snapshot_first)

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.db.models import IndexAttemptError
class IndexAttemptErrorPydantic(BaseModel):
id: int
connector_credential_pair_id: int
document_id: str | None
document_link: str | None
entity_id: str | None
failed_time_range_start: datetime | None
failed_time_range_end: datetime | None
failure_message: str
is_resolved: bool = False
time_created: datetime
index_attempt_id: int
@classmethod
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
return cls(
id=model.id,
connector_credential_pair_id=model.connector_credential_pair_id,
document_id=model.document_id,
document_link=model.document_link,
entity_id=model.entity_id,
failed_time_range_start=model.failed_time_range_start,
failed_time_range_end=model.failed_time_range_end,
failure_message=model.failure_message,
is_resolved=model.is_resolved,
time_created=model.time_created,
index_attempt_id=model.index_attempt_id,
)

View File

@@ -1,5 +1,6 @@
import time
import traceback
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -7,8 +8,11 @@ from datetime import timezone
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
from onyx.background.indexing.tracer import OnyxTracer
from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
from onyx.background.indexing.checkpointing_utils import save_checkpoint
from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
@@ -17,6 +21,8 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -24,15 +30,18 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
@@ -53,6 +62,7 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
@@ -100,7 +110,9 @@ def _get_connector_runner(
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
connector=runnable_connector,
batch_size=batch_size,
time_range=(start_time, end_time),
)
@@ -159,6 +171,66 @@ class RunIndexingContext(BaseModel):
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
Raises a RuntimeError if any conditions are not met.
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
def _check_failure_threshold(
total_failures: int,
document_count: int,
batch_num: int,
last_failure: ConnectorFailure | None,
) -> None:
"""Check if we've hit the failure threshold and raise an appropriate exception if so.
We consider the threshold hit if:
1. We have more than 3 failures AND
2. Failures account for more than 10% of processed documents
"""
failure_ratio = total_failures / (document_count or 1)
FAILURE_THRESHOLD = 3
FAILURE_RATIO_THRESHOLD = 0.1
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
logger.error(
f"Connector run failed with '{total_failures}' errors "
f"after '{batch_num}' batches."
)
if last_failure and last_failure.exception:
raise last_failure.exception from last_failure.exception
raise RuntimeError(
f"Connector run encountered too many errors, aborting. "
f"Last error: {last_failure}"
)
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -169,11 +241,8 @@ def _run_indexing(
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
start_time = time.monotonic() # jsut used for logging
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
@@ -221,6 +290,46 @@ def _run_indexing(
db_session=db_session_temp,
)
)
if last_successful_index_time > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt_start.search_settings_id,
db_session=db_session_temp,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# add start/end now that they have been set
index_attempt_start.poll_range_start = window_start
index_attempt_start.poll_range_end = window_end
db_session_temp.add(index_attempt_start)
db_session_temp.commit()
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
@@ -234,7 +343,6 @@ def _run_indexing(
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt_id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
@@ -246,63 +354,73 @@ def _run_indexing(
callback=callback,
)
tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
tracer.start()
tracer.snap()
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt_md = IndexAttemptMetadata(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
total_failures = 0
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
run_end_dt = None
tracer_counter: int
try:
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
),
source_type=db_connector.source,
)
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_loop_start = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop_start:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
connector_runner = _get_connector_runner(
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
attempt=index_attempt_loop_start,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
db_session=db_session_temp,
)
doc_id_to_unresolved_errors: dict[
str, list[IndexAttemptError]
] = defaultdict(list)
for error in unresolved_errors:
if error.document_id:
doc_id_to_unresolved_errors[error.document_id].append(error)
entity_based_unresolved_errors = [
error for error in unresolved_errors if error.entity_id
]
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
@@ -313,41 +431,37 @@ def _run_indexing(
# TODO: should we move this into the above callback instead?
with get_session_with_tenant(tenant_id) as db_session_temp:
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
(
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_tenant(tenant_id) as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
# if it's deleting, we don't care if this is a secondary index
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(
db_session_temp, index_attempt_id
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
if not index_attempt_loop:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
# save the new checkpoint (if one is provided)
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing logic, so if no batch we can just continue
if document_batch is None:
continue
batch_description = []
doc_batch_cleaned = strip_null_characters(doc_batch)
doc_batch_cleaned = strip_null_characters(document_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
@@ -377,15 +491,51 @@ def _run_indexing(
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# resolve errors for documents that were successfully indexed
failed_document_ids = [
failure.failed_document.document_id
for failure in index_pipeline_result.failures
if failure.failed_document
]
successful_document_ids = [
document.id
for document in document_batch
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_tenant(tenant_id) as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
)
for error in doc_id_to_unresolved_errors[document_id]:
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_tenant(tenant_id) as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures,
document_count,
batch_num,
index_pipeline_result.failures[-1],
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
# so we need either to commit() or to use a new session
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
@@ -397,126 +547,77 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
memory_tracer.increment_and_maybe_trace()
run_end_dt = window_end
if ctx.is_primary:
with get_session_with_tenant(tenant_id) as db_session_temp:
# `make sure the checkpoints aren't getting too large`at some regular interval
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_tenant(tenant_id) as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or (
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
)
or (
index_attempt_loop is not None
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
)
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason="All batches exceptioned.",
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
memory_tracer.stop()
raise e
else:
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
memory_tracer.stop()
elapsed_time = time.monotonic() - start_time
with get_session_with_tenant(tenant_id) as db_session_temp:
if index_attempt_md.num_exceptions == 0:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
if total_failures == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
@@ -535,7 +636,7 @@ def _run_indexing(
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"failures={total_failures} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
@@ -547,7 +648,7 @@ def _run_indexing(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=run_end_dt,
run_dt=window_end,
)
@@ -558,46 +659,43 @@ def run_indexing_entrypoint(
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
try:
if is_ee:
global_version.set_ee()
"""Don't swallow exceptions here ... propagate them up."""
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
credential_id = attempt.connector_credential_pair.credential_id
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
credential_id = attempt.connector_credential_pair.credential_id
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)

View File

@@ -1,77 +0,0 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class OnyxTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -3,6 +3,24 @@ import os
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
ALLOW_REFINEMENT = True
AGENT_DEFAULT_RETRIEVAL_HITS = 15
AGENT_DEFAULT_RERANKING_HITS = 10
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3
AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5
AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = 25
AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = 35
AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
INITIAL_SEARCH_DECOMPOSITION_ENABLED = True
ALLOW_REFINEMENT = True
AGENT_DEFAULT_RETRIEVAL_HITS = 15
AGENT_DEFAULT_RERANKING_HITS = 10
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8
@@ -13,9 +31,21 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
#####
# Agent Configs
#####
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_RETRIEVAL_STATS = (
@@ -77,4 +107,151 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
) # 2000
AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = int(
os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER")
or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
) # 25
AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = int(
os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER")
or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
) # 35
AGENT_RETRIEVAL_STATS = (
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
) or True # default True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
) # 15
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
) # 15
# Reranking agent configs
# Reranking stats - no influence on flow outside of stats collection
AGENT_RERANKING_STATS = (
not os.environ.get("AGENT_RERANKING_STATS") == "True"
) or False # default False
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS
) # 15
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS")
or AGENT_DEFAULT_RERANKING_HITS
) # 10
AGENT_NUM_DOCS_FOR_DECOMPOSITION = int(
os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION")
or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION
) # 3
AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int(
os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION")
or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
) # 5
AGENT_EXPLORATORY_SEARCH_RESULTS = int(
os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS")
or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS
) # 5
AGENT_MIN_ORIG_QUESTION_DOCS = int(
os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS")
or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS
) # 3
AGENT_MAX_ANSWER_CONTEXT_DOCS = int(
os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS")
or AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS
) # 8
AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
os.environ.get("AGENT_MAX_STATIC_HISTORY_WORD_LENGTH")
or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH
) # 2000
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 1
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 4
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
) # 8
GRAPH_VERSION_NAME: str = "a"

View File

@@ -169,6 +169,11 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -621,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes

View File

@@ -107,9 +107,9 @@ CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
CELERY_PRUNING_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
@@ -125,6 +125,7 @@ class DocumentSource(str, Enum):
GMAIL = "gmail"
REQUESTTRACKER = "requesttracker"
GITHUB = "github"
GITBOOK = "gitbook"
GITLAB = "gitlab"
GURU = "guru"
BOOKSTACK = "bookstack"
@@ -164,6 +165,9 @@ class DocumentSource(str, Enum):
EGNYTE = "egnyte"
AIRTABLE = "airtable"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -242,6 +246,7 @@ class FileOrigin(str, Enum):
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
OTHER = "other"
@@ -273,6 +278,7 @@ class OnyxCeleryQueues:
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
LLM_MODEL_UPDATE = "llm_model_update"
CHECKPOINT_CLEANUP = "checkpoint_cleanup"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
@@ -292,13 +298,13 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
@@ -324,6 +330,7 @@ class OnyxRedisSignals:
BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
"signal:block_validate_permission_sync_fences"
)
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
@@ -354,7 +361,10 @@ class OnyxCeleryTask:
DEFAULT = "celery"
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
CLOUD_MONITOR_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_alembic"
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -364,8 +374,12 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
CLEANUP_CHECKPOINT = "cleanup_checkpoint"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (

View File

@@ -0,0 +1,6 @@
import os
SKIP_CONNECTION_POOL_WARM_UP = (
os.environ.get("SKIP_CONNECTION_POOL_WARM_UP", "").lower() == "true"
)

View File

@@ -245,7 +245,7 @@ class AirtableConnector(LoadConnector):
return [(" ".join(combined) if combined else str(field_info), default_link)]
if isinstance(field_info, list):
return [(item, default_link) for item in field_info]
return [(str(item), default_link) for item in field_info]
return [(str(field_info), default_link)]
@@ -268,7 +268,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
) -> tuple[list[Section], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -342,7 +342,7 @@ class AirtableConnector(LoadConnector):
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
primary_field_value = (

View File

@@ -1,11 +1,16 @@
import sys
import time
from collections.abc import Generator
from datetime import datetime
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger
@@ -15,48 +20,139 @@ logger = setup_logger()
TimeRange = tuple[datetime, datetime]
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: ConnectorCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput,
) -> Generator[
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput,
) -> CheckpointOutput:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
class ConnectorRunner:
"""
Handles:
- Batching
- Additional exception logging
- Combining different connector types to a single interface
"""
def __init__(
self,
connector: BaseConnector,
batch_size: int,
time_range: TimeRange | None = None,
fail_loudly: bool = False,
):
self.connector = connector
self.time_range = time_range
self.batch_size = batch_size
if isinstance(self.connector, PollConnector):
if time_range is None:
raise ValueError("time_range is required for PollConnector")
self.doc_batch: list[Document] = []
self.doc_batch_generator = self.connector.poll_source(
time_range[0].timestamp(), time_range[1].timestamp()
)
elif isinstance(self.connector, LoadConnector):
if time_range and fail_loudly:
raise ValueError(
"time_range specified, but passed in connector is not a PollConnector"
)
self.doc_batch_generator = self.connector.load_from_state()
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
def run(self) -> GenerateDocumentsOutput:
def run(
self, checkpoint: ConnectorCheckpoint
) -> Generator[
tuple[
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
],
None,
None,
]:
"""Adds additional exception logging to the connector."""
try:
start = time.monotonic()
for batch in self.doc_batch_generator:
# to know how long connector is taking
logger.debug(
f"Connector took {time.monotonic() - start} seconds to build a batch."
)
yield batch
if isinstance(self.connector, CheckpointConnector):
if self.time_range is None:
raise ValueError("time_range is required for CheckpointConnector")
start = time.monotonic()
checkpoint_connector_generator = self.connector.load_from_checkpoint(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: ConnectorCheckpoint | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
checkpoint_connector_generator
):
if document is not None:
self.doc_batch.append(document)
if failure is not None:
yield None, failure, None
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
self.doc_batch = []
# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
self.doc_batch = []
yield None, None, next_checkpoint
logger.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
)
else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")
for document_batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
yield None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
yield None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
@@ -76,6 +172,6 @@ class ConnectorRunner:
)
logger.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str}"
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise

View File

@@ -20,6 +20,7 @@ from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
from onyx.connectors.gitbook.connector import GitbookConnector
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.gitlab.connector import GitlabConnector
from onyx.connectors.gmail.connector import GmailConnector
@@ -29,12 +30,14 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.onyx_jira.connector import JiraConnector
@@ -42,7 +45,7 @@ from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
@@ -65,12 +68,13 @@ def identify_connector_class(
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.POLL: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
InputType.POLL: SlackConnector,
InputType.SLIM_RETRIEVAL: SlackConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
DocumentSource.GITLAB: GitlabConnector,
DocumentSource.GITBOOK: GitbookConnector,
DocumentSource.GOOGLE_DRIVE: GoogleDriveConnector,
DocumentSource.BOOKSTACK: BookstackConnector,
DocumentSource.CONFLUENCE: ConfluenceConnector,
@@ -107,6 +111,8 @@ def identify_connector_class(
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}
connector_by_source = connector_map.get(source, {})
@@ -123,10 +129,23 @@ def identify_connector_class(
if any(
[
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector),
input_type == InputType.POLL and not issubclass(connector, PollConnector),
input_type == InputType.EVENT and not issubclass(connector, EventConnector),
(
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector)
),
(
input_type == InputType.POLL
# either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointConnector)
)
),
(
input_type == InputType.EVENT
and not issubclass(connector, EventConnector)
),
]
):
raise ConnectorMissingException(

View File

@@ -0,0 +1,279 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import urljoin
import requests
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
GITBOOK_API_BASE = "https://api.gitbook.com/v1/"
class GitbookApiClient:
def __init__(self, access_token: str) -> None:
self.access_token = access_token
def get(self, endpoint: str, params: dict[str, Any] | None = None) -> Any:
headers = {
"Authorization": f"Bearer {self.access_token}",
"Content-Type": "application/json",
}
url = urljoin(GITBOOK_API_BASE, endpoint.lstrip("/"))
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
return response.json()
def get_page_content(self, space_id: str, page_id: str) -> dict[str, Any]:
return self.get(f"/spaces/{space_id}/content/page/{page_id}")
def _extract_text_from_document(document: dict[str, Any]) -> str:
"""Extract text content from GitBook document structure by parsing the document nodes
into markdown format."""
def parse_leaf(leaf: dict[str, Any]) -> str:
text = leaf.get("text", "")
leaf.get("marks", [])
return text
def parse_text_node(node: dict[str, Any]) -> str:
text = ""
for leaf in node.get("leaves", []):
text += parse_leaf(leaf)
return text
def parse_block_node(node: dict[str, Any]) -> str:
block_type = node.get("type", "")
result = ""
if block_type == "heading-1":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"# {text}\n\n"
elif block_type == "heading-2":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"## {text}\n\n"
elif block_type == "heading-3":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"### {text}\n\n"
elif block_type == "heading-4":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"#### {text}\n\n"
elif block_type == "heading-5":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"##### {text}\n\n"
elif block_type == "heading-6":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"###### {text}\n\n"
elif block_type == "list-unordered":
for list_item in node.get("nodes", []):
paragraph = list_item.get("nodes", [])[0]
text = "".join(parse_text_node(n) for n in paragraph.get("nodes", []))
result += f"* {text}\n"
result += "\n"
elif block_type == "paragraph":
text = "".join(parse_text_node(n) for n in node.get("nodes", []))
result = f"{text}\n\n"
elif block_type == "list-tasks":
for task_item in node.get("nodes", []):
checked = task_item.get("data", {}).get("checked", False)
paragraph = task_item.get("nodes", [])[0]
text = "".join(parse_text_node(n) for n in paragraph.get("nodes", []))
checkbox = "[x]" if checked else "[ ]"
result += f"- {checkbox} {text}\n"
result += "\n"
elif block_type == "code":
for code_line in node.get("nodes", []):
if code_line.get("type") == "code-line":
text = "".join(
parse_text_node(n) for n in code_line.get("nodes", [])
)
result += f"{text}\n"
result += "\n"
elif block_type == "blockquote":
for quote_node in node.get("nodes", []):
if quote_node.get("type") == "paragraph":
text = "".join(
parse_text_node(n) for n in quote_node.get("nodes", [])
)
result += f"> {text}\n"
result += "\n"
elif block_type == "table":
records = node.get("data", {}).get("records", {})
definition = node.get("data", {}).get("definition", {})
view = node.get("data", {}).get("view", {})
columns = view.get("columns", [])
header_cells = []
for col_id in columns:
col_def = definition.get(col_id, {})
header_cells.append(col_def.get("title", ""))
result = "| " + " | ".join(header_cells) + " |\n"
result += "|" + "---|" * len(header_cells) + "\n"
sorted_records = sorted(
records.items(), key=lambda x: x[1].get("orderIndex", "")
)
for record_id, record_data in sorted_records:
values = record_data.get("values", {})
row_cells = []
for col_id in columns:
fragment_id = values.get(col_id, "")
fragment_text = ""
for fragment in node.get("fragments", []):
if fragment.get("fragment") == fragment_id:
for frag_node in fragment.get("nodes", []):
if frag_node.get("type") == "paragraph":
fragment_text = "".join(
parse_text_node(n)
for n in frag_node.get("nodes", [])
)
break
row_cells.append(fragment_text)
result += "| " + " | ".join(row_cells) + " |\n"
result += "\n"
return result
if not document or "document" not in document:
return ""
markdown = ""
nodes = document["document"].get("nodes", [])
for node in nodes:
markdown += parse_block_node(node)
return markdown
def _convert_page_to_document(
client: GitbookApiClient, space_id: str, page: dict[str, Any]
) -> Document:
page_id = page["id"]
page_content = client.get_page_content(space_id, page_id)
return Document(
id=f"gitbook-{space_id}-{page_id}",
sections=[
Section(
link=page.get("urls", {}).get("app", ""),
text=_extract_text_from_document(page_content),
)
],
source=DocumentSource.GITBOOK,
semantic_identifier=page.get("title", ""),
doc_updated_at=datetime.fromisoformat(page["updatedAt"]).replace(
tzinfo=timezone.utc
),
metadata={
"path": page.get("path", ""),
"type": page.get("type", ""),
"kind": page.get("kind", ""),
},
)
class GitbookConnector(LoadConnector, PollConnector):
def __init__(
self,
space_id: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.space_id = space_id
self.batch_size = batch_size
self.access_token: str | None = None
self.client: GitbookApiClient | None = None
def load_credentials(self, credentials: dict[str, Any]) -> None:
access_token = credentials.get("gitbook_api_key")
if not access_token:
raise ConnectorMissingCredentialError("GitBook access token")
self.access_token = access_token
self.client = GitbookApiClient(access_token)
def _fetch_all_pages(
self,
start: datetime | None = None,
end: datetime | None = None,
) -> GenerateDocumentsOutput:
if not self.client:
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
pages = content.get("pages", [])
current_batch: list[Document] = []
for page in pages:
updated_at = datetime.fromisoformat(page["updatedAt"])
if start and updated_at < start:
if current_batch:
yield current_batch
return
if end and updated_at > end:
continue
current_batch.append(
_convert_page_to_document(self.client, self.space_id, page)
)
if len(current_batch) >= self.batch_size:
yield current_batch
current_batch = []
if current_batch:
yield current_batch
except requests.RequestException as e:
logger.error(f"Error fetching GitBook content: {str(e)}")
raise
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_all_pages()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
return self._fetch_all_pages(start_datetime, end_datetime)
if __name__ == "__main__":
import os
connector = GitbookConnector(
space_id=os.environ["GITBOOK_SPACE_ID"],
)
connector.load_credentials({"gitbook_api_key": os.environ["GITBOOK_API_KEY"]})
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -302,7 +302,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if e.status_code == 401:
# fail gracefully, let the other impersonations continue
# one user without access shouldn't block the entire connector
logger.exception(
logger.warning(
f"User '{user_email}' does not have access to the drive APIs."
)
return

View File

@@ -1,10 +1,13 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -14,6 +17,7 @@ SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
class BaseConnector(abc.ABC):
@@ -105,3 +109,33 @@ class EventConnector(BaseConnector):
@abc.abstractmethod
def handle_event(self, event: Any) -> GenerateDocumentsOutput:
raise NotImplementedError
class CheckpointConnector(BaseConnector):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Yields back documents or failures. Final return is the new checkpoint.
Final return can be access via either:
```
try:
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value # Extracting the return value
print(checkpoint)
```
OR
```
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
```
"""
raise NotImplementedError

View File

@@ -0,0 +1,86 @@
from typing import Any
import httpx
from pydantic import BaseModel
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger
logger = setup_logger()
class SingleConnectorYield(BaseModel):
documents: list[Document]
checkpoint: ConnectorCheckpoint
failures: list[ConnectorFailure]
unhandled_exception: str | None = None
class MockConnector(CheckpointConnector):
def __init__(
self,
mock_server_host: str,
mock_server_port: int,
) -> None:
self.mock_server_host = mock_server_host
self.mock_server_port = mock_server_port
self.client = httpx.Client(timeout=30.0)
self.connector_yields: list[SingleConnectorYield] | None = None
self.current_yield_index: int = 0
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
response = self.client.get(self._get_mock_server_url("get-documents"))
response.raise_for_status()
data = response.json()
self.connector_yields = [
SingleConnectorYield(**yield_data) for yield_data in data
]
return None
def _get_mock_server_url(self, endpoint: str) -> str:
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
response = self.client.post(
self._get_mock_server_url("add-checkpoint"),
json=checkpoint.model_dump(mode="json"),
)
response.raise_for_status()
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
if self.connector_yields is None:
raise ValueError("No connector yields configured")
# Save the checkpoint to the mock server
self._save_checkpoint(checkpoint)
yield_index = self.current_yield_index
self.current_yield_index += 1
current_yield = self.connector_yields[yield_index]
# If the current yield has an unhandled exception, raise it
# This is used to simulate an unhandled failure in the connector.
if current_yield.unhandled_exception:
raise RuntimeError(current_yield.unhandled_exception)
# yield all documents
for document in current_yield.documents:
yield document
for failure in current_yield.failures:
yield failure
return current_yield.checkpoint

View File

@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import model_validator
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -187,36 +188,48 @@ class SlimDocument(BaseModel):
perm_sync_data: Any | None = None
class DocumentErrorSummary(BaseModel):
id: str
semantic_id: str
section_link: str | None
@classmethod
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
return cls(
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
)
@classmethod
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
return cls(
id=str(data.get("id")),
semantic_id=str(data.get("semantic_id")),
section_link=str(data.get("section_link")),
)
def to_dict(self) -> dict[str, str | None]:
return {
"id": self.id,
"semantic_id": self.semantic_id,
"section_link": self.section_link,
}
class IndexAttemptMetadata(BaseModel):
batch_num: int | None = None
num_exceptions: int = 0
connector_id: int
credential_id: int
class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool
@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
class DocumentFailure(BaseModel):
document_id: str
document_link: str | None = None
class EntityFailure(BaseModel):
entity_id: str
missed_time_range: tuple[datetime, datetime] | None = None
class ConnectorFailure(BaseModel):
failed_document: DocumentFailure | None = None
failed_entity: EntityFailure | None = None
failure_message: str
exception: Exception | None = None
model_config = {"arbitrary_types_allowed": True}
@model_validator(mode="before")
def check_failed_fields(cls, values: dict) -> dict:
failed_document = values.get("failed_document")
failed_entity = values.get("failed_entity")
if (failed_document is None and failed_entity is None) or (
failed_document is not None and failed_entity is not None
):
raise ValueError(
"Exactly one of 'failed_document' or 'failed_entity' must be specified."
)
return values

View File

@@ -145,7 +145,8 @@ def fetch_jira_issues_batch(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=issue.fields.summary,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed

View File

@@ -1,10 +1,16 @@
import contextvars
import copy
import re
from collections.abc import Callable
from collections.abc import Generator
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@@ -12,14 +18,18 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.utils import expert_info_from_slack_id
@@ -33,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLACK_LIMIT = 900
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
@@ -40,6 +52,13 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
@@ -140,6 +159,10 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
return f"{channel_id}__{thread_ts}"
def thread_to_doc(
channel: ChannelType,
thread: ThreadType,
@@ -182,7 +205,7 @@ def thread_to_doc(
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
@@ -267,64 +290,97 @@ def filter_channels(
]
def _get_all_docs(
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
"""Get a channel by its ID.
Args:
client: The Slack WebClient instance
channel_id: The ID of the channel to fetch
Returns:
The channel information
Raises:
SlackApiError: If the channel cannot be fetched
"""
response = make_slack_api_call_w_retries(
client.conversations_info,
channel=channel_id,
)
return cast(ChannelType, response["channel"])
def _get_messages(
channel: ChannelType,
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
) -> tuple[list[MessageType], bool]:
"""Slack goes from newest to oldest."""
# Cache to prevent refetching via API since users
user_cache: dict[str, BasicExpertInfo | None] = {}
# have to be in the channel in order to read messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
response = make_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
limit=_SLACK_LIMIT,
)
response.validate()
for channel in filtered_channels:
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
messages = cast(list[MessageType], response.get("messages", []))
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
"next_cursor", ""
)
has_more = bool(cursor)
return messages, has_more
def _message_to_doc(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Document | None:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
return None
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
return thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
seen_thread_ts: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
continue
seen_thread_ts.add(thread_ts)
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
logger.info(
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
)
return None
def _get_all_doc_ids(
@@ -368,7 +424,7 @@ def _get_all_doc_ids(
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=f"{channel_id}__{message_ts}",
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
perm_sync_data={"channel_id": channel_id},
)
)
@@ -376,7 +432,51 @@ def _get_all_doc_ids(
yield channel_metadata_list
class SlackPollConnector(PollConnector, SlimConnector):
def _process_message(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
thread_ts = message.get("thread_ts")
try:
# causes random failures for testing checkpointing / continue on failure
# import random
# if random.random() > 0.95:
# raise RuntimeError("Random failure :P")
doc = _message_to_doc(
message=message,
client=client,
channel=channel,
slack_cleaner=slack_cleaner,
user_cache=user_cache,
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
),
document_link=get_message_link(message, client, channel["id"]),
),
failure_message=str(e),
exception=e,
),
)
class SlackConnector(SlimConnector, CheckpointConnector):
def __init__(
self,
channels: list[str] | None = None,
@@ -390,9 +490,14 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.batch_size = batch_size
self.client: WebClient | None = None
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
def retrieve_all_slim_documents(
@@ -411,30 +516,155 @@ class SlackPollConnector(PollConnector, SlimConnector):
callback=callback,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
Step 2: Loop through each channel. For each channel:
Step 2.1: Get messages within the time range.
Step 2.2: Process messages in parallel, yield back docs.
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
Slack returns messages from newest to oldest, so we need to keep track of
the latest message we've seen in each channel.
Step 2.4: If there are no more messages in the channel, switch the current
channel to the next channel.
"""
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
# throw an error if we use 0.0 on an account without infinite data
# retention
oldest=str(start) if start else None,
latest=str(end),
):
documents.append(document)
if len(documents) >= self.batch_size:
yield documents
documents = []
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
if documents:
yield documents
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
if len(filtered_channels) == 0:
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
message_batch, has_more_in_channel = _get_messages(
channel, self.client, oldest, latest
)
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=8) as executor:
futures: list[Future] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
futures.append(
executor.submit(
current_context.run,
_process_message,
message=message,
client=self.client,
channel=channel,
slack_cleaner=self.text_cleaner,
user_cache=self.user_cache,
seen_thread_ts=seen_thread_ts,
)
)
for future in as_completed(futures):
doc, thread_ts, failures = future.result()
if doc:
# handle race conditions here since this is single
# threaded. Multi-threaded _process_message reads from this
# but since this is single threaded, we won't run into simul
# writes. At worst, we can duplicate a thread, which will be
# deduped later on.
if thread_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failures:
for failure in failures:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
return checkpoint
except Exception as e:
logger.exception(f"Error processing channel {channel['name']}")
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=channel["id"],
missed_time_range=(
datetime.fromtimestamp(start, tz=timezone.utc),
datetime.fromtimestamp(end, tz=timezone.utc),
),
),
failure_message=str(e),
exception=e,
)
return checkpoint
if __name__ == "__main__":
@@ -442,7 +672,7 @@ if __name__ == "__main__":
import time
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
connector = SlackConnector(
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
@@ -450,6 +680,17 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
print(next(document_batches))
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):
print(document_or_failure)
elif isinstance(document_or_failure, ConnectorFailure):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value
print("Next checkpoint:", checkpoint)
print("Next checkpoint:", checkpoint)

View File

@@ -34,9 +34,14 @@ def get_message_link(
) -> str:
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
message_ts_without_dot = message_ts.replace(".", "")
thread_ts = event.get("thread_ts")
base_url = get_base_url(client.token)
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
f"?thread_ts={thread_ts}" if thread_ts else ""
)
return link
def _make_slack_api_call_paginated(

View File

@@ -1,9 +1,14 @@
import os
import tempfile
import urllib.parse
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
from zulip import Client
@@ -36,8 +41,39 @@ class ZulipConnector(LoadConnector, PollConnector):
) -> None:
self.batch_size = batch_size
self.realm_name = realm_name
self.realm_url = realm_url if realm_url.endswith("/") else realm_url + "/"
self.client: Client | None = None
# Clean and normalize the URL
realm_url = realm_url.strip().lower()
# Remove any trailing slashes
realm_url = realm_url.rstrip("/")
# Ensure the URL has a scheme
if not realm_url.startswith(("http://", "https://")):
realm_url = f"https://{realm_url}"
try:
parsed = urllib.parse.urlparse(realm_url)
# Extract the base domain without any paths or ports
netloc = parsed.netloc.split(":")[0] # Remove port if present
if not netloc:
raise ValueError(
f"Invalid realm URL format: {realm_url}. "
f"URL must include a valid domain name."
)
# Always use HTTPS for security
self.base_url = f"https://{netloc}"
self.client: Client | None = None
except Exception as e:
raise ValueError(
f"Failed to parse Zulip realm URL: {realm_url}. "
f"Please provide a URL in the format: domain.com or https://domain.com. "
f"Error: {str(e)}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
contents = credentials["zuliprc_content"]
@@ -55,12 +91,17 @@ class ZulipConnector(LoadConnector, PollConnector):
return None
def _message_to_narrow_link(self, m: Message) -> str:
stream_name = m.display_recipient # assume str
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
topic_operand = encode_zulip_narrow_operand(m.subject)
try:
stream_name = m.display_recipient # assume str
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
topic_operand = encode_zulip_narrow_operand(m.subject)
narrow_link = f"{self.realm_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
return narrow_link
narrow_link = f"{self.base_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
return narrow_link
except Exception as e:
logger.error(f"Error generating Zulip message link: {e}")
# Fallback to a basic link that at least includes the base URL
return f"{self.base_url}#narrow/id/{m.id}"
def _get_message_batch(self, anchor: str) -> Tuple[bool, List[Message]]:
if self.client is None:
@@ -83,6 +124,40 @@ class ZulipConnector(LoadConnector, PollConnector):
def _message_to_doc(self, message: Message) -> Document:
text = f"{message.sender_full_name}: {message.content}"
try:
# Convert timestamps to UTC datetime objects
post_time = datetime.fromtimestamp(message.timestamp, tz=timezone.utc)
edit_time = (
datetime.fromtimestamp(message.last_edit_timestamp, tz=timezone.utc)
if message.last_edit_timestamp is not None
else None
)
# Use the most recent edit time if available, otherwise use post time
doc_time = edit_time if edit_time is not None else post_time
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse timestamp for message {message.id}: {e}")
post_time = None
edit_time = None
doc_time = None
metadata: Dict[str, Union[str, List[str]]] = {
"stream_name": str(message.display_recipient),
"topic": str(message.subject),
"sender_name": str(message.sender_full_name),
"sender_email": str(message.sender_email),
"message_timestamp": str(message.timestamp),
"message_id": str(message.id),
"stream_id": str(message.stream_id),
"has_reactions": str(len(message.reactions) > 0),
"content_type": str(message.content_type or "text"),
}
# Always include edit timestamp in metadata when available
if edit_time is not None:
metadata["edit_timestamp"] = str(message.last_edit_timestamp)
return Document(
id=f"{message.stream_id}__{message.id}",
sections=[
@@ -92,8 +167,9 @@ class ZulipConnector(LoadConnector, PollConnector):
)
],
source=DocumentSource.ZULIP,
semantic_identifier=message.display_recipient or message.subject,
metadata={},
semantic_identifier=f"{message.display_recipient} > {message.subject}",
metadata=metadata,
doc_updated_at=doc_time, # Use most recent edit time or post time
)
def _get_docs(

View File

@@ -1,6 +1,7 @@
from typing import Any
from typing import List
from typing import Optional
from typing import Union
from pydantic import BaseModel
from pydantic import Field
@@ -19,7 +20,7 @@ class Message(BaseModel):
sender_realm_str: str
subject: str
topic_links: Optional[List[Any]] = None
last_edit_timestamp: Optional[int]
last_edit_timestamp: Optional[int] = None
edit_history: Any = None
reactions: List[Any]
submessages: List[Any]
@@ -39,5 +40,5 @@ class GetMessagesResponse(BaseModel):
found_oldest: Optional[bool] = None
found_newest: Optional[bool] = None
history_limited: Optional[bool] = None
anchor: Optional[str] = None
anchor: Optional[Union[str, int]] = None
messages: List[Message] = Field(default_factory=list)

View File

@@ -51,6 +51,7 @@ class SearchPipeline:
user: User | None,
llm: LLM,
fast_llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: (
@@ -61,10 +62,13 @@ class SearchPipeline:
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
self.search_request = search_request
self.user = user
self.llm = llm
self.fast_llm = fast_llm
self.skip_query_analysis = skip_query_analysis
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -106,6 +110,7 @@ class SearchPipeline:
search_request=self.search_request,
user=self.user,
llm=self.llm,
skip_query_analysis=self.skip_query_analysis,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
@@ -160,6 +165,12 @@ class SearchPipeline:
that have a corresponding chunk.
This step should be fast for any document index implementation.
Current implementation timing is approximately broken down in timing as:
- 200 ms to get the embedding of the query
- 15 ms to get chunks from the document index
- possibly more to get additional surrounding chunks
- possibly more for query expansion (multilingual)
"""
if self._retrieved_sections is not None:
return self._retrieved_sections

View File

@@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.document_index.document_index_utils import (
@@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
@log_function_time(print_only=True)
def semantic_reranking(
query: SearchQuery,
query_str: str,
rerank_settings: RerankingDetails,
chunks: list[InferenceChunk],
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
@@ -88,11 +90,9 @@ def semantic_reranking(
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
rerank_settings = query.rerank_settings
if not rerank_settings or not rerank_settings.rerank_model_name:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking flow should not be running")
assert (
rerank_settings.rerank_model_name
), "Reranking flow cannot run without a specific model"
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
@@ -107,7 +107,7 @@ def semantic_reranking(
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks_to_rerank
]
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
# Old logic to handle multiple cross-encoders preserved but not used
sim_scores = [numpy.array(sim_scores_floats)]
@@ -165,8 +165,20 @@ def semantic_reranking(
return list(ranked_chunks), list(ranked_indices)
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
"""Based on the RerankingDetails model, only run rerank if the following conditions are met:
- rerank_model_name is not None
- num_rerank is greater than 0
"""
if not rerank_settings:
return False
return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
def rerank_sections(
query: SearchQuery,
query_str: str,
rerank_settings: RerankingDetails,
sections_to_rerank: list[InferenceSection],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceSection]:
@@ -181,16 +193,13 @@ def rerank_sections(
"""
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
if not query.rerank_settings:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
ranked_chunks, _ = semantic_reranking(
query=query,
query_str=query_str,
rerank_settings=rerank_settings,
chunks=chunks_to_rerank,
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important
@@ -260,16 +269,13 @@ def search_postprocessing(
rerank_task_id = None
sections_yielded = False
if (
search_query.rerank_settings
and search_query.rerank_settings.rerank_model_name
and search_query.rerank_settings.num_rerank > 0
):
if should_rerank(search_query.rerank_settings):
post_processing_tasks.append(
FunctionCall(
rerank_sections,
(
search_query,
search_query.query,
search_query.rerank_settings, # Cannot be None here
retrieved_sections,
rerank_metrics_callback,
),

View File

@@ -50,11 +50,11 @@ def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False,
skip_query_analysis: bool = False,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
base_recency_decay: float = BASE_RECENCY_DECAY,
bypass_acl: bool = False,
) -> SearchQuery:
"""Logic is as follows:
Any global disables apply first
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (None, None)
else (False, None)
)
all_query_terms = query.split()

View File

@@ -0,0 +1,10 @@
from sqlalchemy.orm import Session
from onyx.db.models import BackgroundError
def create_background_error(
db_session: Session, message: str, cc_pair_id: int | None
) -> None:
db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id))
db_session.commit()

View File

@@ -350,13 +350,17 @@ def delete_chat_session(
user_id: UUID | None,
chat_session_id: UUID,
db_session: Session,
include_deleted: bool = False,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
chat_session_id=chat_session_id,
user_id=user_id,
db_session=db_session,
include_deleted=include_deleted,
)
if chat_session.deleted:
if chat_session.deleted and not include_deleted:
raise ValueError("Cannot delete an already deleted chat session")
if hard_delete:
@@ -380,7 +384,15 @@ def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
).fetchall()
for user_id, session_id in old_sessions:
delete_chat_session(user_id, session_id, db_session, hard_delete=True)
try:
delete_chat_session(
user_id, session_id, db_session, include_deleted=True, hard_delete=True
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
def get_chat_message(
@@ -893,14 +905,18 @@ def translate_db_sub_questions_to_server_objects(
question=sub_question.sub_question,
answer=sub_question.sub_answer,
sub_queries=sub_queries,
context_docs=get_retrieval_docs_from_search_docs(verified_docs),
context_docs=get_retrieval_docs_from_search_docs(
verified_docs, sort_by_score=False
),
)
)
return sub_questions
def get_retrieval_docs_from_search_docs(
search_docs: list[SearchDoc], remove_doc_content: bool = False
search_docs: list[SearchDoc],
remove_doc_content: bool = False,
sort_by_score: bool = True,
) -> RetrievalDocs:
top_documents = [
translate_db_search_doc_to_server_search_doc(
@@ -908,7 +924,8 @@ def get_retrieval_docs_from_search_docs(
)
for db_doc in search_docs
]
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
if sort_by_score:
top_documents = sorted(top_documents, key=lambda doc: doc.score, reverse=True) # type: ignore
return RetrievalDocs(top_documents=top_documents)
@@ -1018,7 +1035,7 @@ def log_agent_sub_question_results(
sub_question = sub_question_answer_result.question
sub_answer = sub_question_answer_result.answer
sub_document_results = _create_citation_format_list(
sub_question_answer_result.verified_reranked_documents
sub_question_answer_result.context_documents
)
sub_question_object = AgentSubQuestion(

View File

@@ -18,6 +18,7 @@ import boto3
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -39,6 +40,7 @@ from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
@@ -187,20 +189,45 @@ class SqlEngine:
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
DEFAULT_ENGINE_KWARGS = {
"pool_size": 20,
"max_overflow": 5,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
def _init_engine(
cls, host: str, port: str, db: str, **engine_kwargs: Any
) -> Engine:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
db_api=SYNC_DB_API,
host=host,
port=port,
db=db,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
engine = create_engine(connection_string, **merged_kwargs)
# Start with base kwargs that are valid for all pool types
final_engine_kwargs: dict[str, Any] = {}
if POSTGRES_USE_NULL_POOL:
# if null pool is specified, then we need to make sure that
# we remove any passed in kwargs related to pool size that would
# cause the initialization to fail
final_engine_kwargs.update(engine_kwargs)
final_engine_kwargs["poolclass"] = pool.NullPool
if "pool_size" in final_engine_kwargs:
del final_engine_kwargs["pool_size"]
if "max_overflow" in final_engine_kwargs:
del final_engine_kwargs["max_overflow"]
else:
final_engine_kwargs["pool_size"] = 20
final_engine_kwargs["max_overflow"] = 5
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
# any passed in kwargs override the defaults
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
@@ -211,15 +238,19 @@ class SqlEngine:
def init_engine(cls, **engine_kwargs: Any) -> None:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
cls._engine = cls._init_engine(
host=engine_kwargs.get("host", POSTGRES_HOST),
port=engine_kwargs.get("port", POSTGRES_PORT),
db=engine_kwargs.get("db", POSTGRES_DB),
**engine_kwargs,
)
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
cls.init_engine()
return cls._engine # type: ignore
@classmethod
def set_app_name(cls, app_name: str) -> None:
@@ -299,13 +330,21 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
connect_args["ssl"] = ssl_context
engine_kwargs = {
"connect_args": connect_args,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
if POSTGRES_USE_NULL_POOL:
engine_kwargs["poolclass"] = pool.NullPool
else:
engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args=connect_args,
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
**engine_kwargs,
)
if USE_IAM_AUTH:

View File

@@ -11,8 +11,7 @@ from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentErrorSummary
from onyx.connectors.models import ConnectorFailure
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -41,6 +40,27 @@ def get_last_attempt_for_cc_pair(
)
def get_recent_completed_attempts_for_cc_pair(
cc_pair_id: int,
search_settings_id: int,
limit: int,
db_session: Session,
) -> list[IndexAttempt]:
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.search_settings_id == search_settings_id,
IndexAttempt.status.notin_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
.order_by(IndexAttempt.time_updated.desc())
.limit(limit)
.all()
)
def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
@@ -615,23 +635,32 @@ def count_unique_cc_pairs_with_successful_index_attempts(
def create_index_attempt_error(
index_attempt_id: int | None,
batch: int | None,
docs: list[Document],
exception_msg: str,
exception_traceback: str,
connector_credential_pair_id: int,
failure: ConnectorFailure,
db_session: Session,
) -> int:
doc_summaries = []
for doc in docs:
doc_summary = DocumentErrorSummary.from_document(doc)
doc_summaries.append(doc_summary.to_dict())
new_error = IndexAttemptError(
index_attempt_id=index_attempt_id,
batch=batch,
doc_summaries=doc_summaries,
error_msg=exception_msg,
traceback=exception_traceback,
connector_credential_pair_id=connector_credential_pair_id,
document_id=(
failure.failed_document.document_id if failure.failed_document else None
),
document_link=(
failure.failed_document.document_link if failure.failed_document else None
),
entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None),
failed_time_range_start=(
failure.failed_entity.missed_time_range[0]
if failure.failed_entity and failure.failed_entity.missed_time_range
else None
),
failed_time_range_end=(
failure.failed_entity.missed_time_range[1]
if failure.failed_entity and failure.failed_entity.missed_time_range
else None
),
failure_message=failure.failure_message,
is_resolved=False,
)
db_session.add(new_error)
db_session.commit()
@@ -649,3 +678,42 @@ def get_index_attempt_errors(
errors = db_session.scalars(stmt)
return list(errors.all())
def count_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
unresolved_only: bool,
db_session: Session,
) -> int:
stmt = (
select(func.count())
.select_from(IndexAttemptError)
.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
)
if unresolved_only:
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
result = db_session.scalar(stmt)
return 0 if result is None else result
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
unresolved_only: bool,
db_session: Session,
page: int | None = None,
page_size: int | None = None,
) -> list[IndexAttemptError]:
stmt = select(IndexAttemptError).where(
IndexAttemptError.connector_credential_pair_id == cc_pair_id
)
if unresolved_only:
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
# Order by most recent first
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
if page is not None and page_size is not None:
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.scalars(stmt).all())

View File

@@ -483,6 +483,10 @@ class ConnectorCredentialPair(Base):
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
background_errors: Mapped[list["BackgroundError"]] = relationship(
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
)
class Document(Base):
__tablename__ = "document"
@@ -823,6 +827,19 @@ class IndexAttempt(Base):
nullable=True,
)
# for polling connectors, the start and end time of the poll window
# will be set when the index attempt starts
poll_range_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)
poll_range_end: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)
# Points to the last checkpoint that was saved for this run. The pointer here
# can be taken to the FileStore to grab the actual checkpoint value
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -866,6 +883,13 @@ class IndexAttempt(Base):
desc("time_updated"),
unique=False,
),
Index(
"ix_index_attempt_cc_pair_settings_poll",
"connector_credential_pair_id",
"search_settings_id",
"status",
desc("time_updated"),
),
)
def __repr__(self) -> str:
@@ -882,25 +906,33 @@ class IndexAttempt(Base):
class IndexAttemptError(Base):
"""
Represents an error that was encountered during an IndexAttempt.
"""
__tablename__ = "index_attempt_errors"
id: Mapped[int] = mapped_column(primary_key=True)
index_attempt_id: Mapped[int] = mapped_column(
ForeignKey("index_attempt.id"),
nullable=True,
nullable=False,
)
connector_credential_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"),
nullable=False,
)
# The index of the batch where the error occurred (if looping thru batches)
# Just informational.
batch: Mapped[int | None] = mapped_column(Integer, default=None)
doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB())
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
traceback: Mapped[str | None] = mapped_column(Text, default=None)
document_id: Mapped[str | None] = mapped_column(String, nullable=True)
document_link: Mapped[str | None] = mapped_column(String, nullable=True)
entity_id: Mapped[str | None] = mapped_column(String, nullable=True)
failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
failure_message: Mapped[str] = mapped_column(Text)
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -909,21 +941,6 @@ class IndexAttemptError(Base):
# This is the reverse side of the relationship
index_attempt = relationship("IndexAttempt", back_populates="error_rows")
__table_args__ = (
Index(
"index_attempt_id",
"time_created",
),
)
def __repr__(self) -> str:
return (
f"<IndexAttempt(id={self.id!r}, "
f"index_attempt_id={self.index_attempt_id!r}, "
f"error_msg={self.error_msg!r})>"
f"time_created={self.time_created!r}, "
)
class SyncRecord(Base):
"""
@@ -2115,6 +2132,31 @@ class StandardAnswer(Base):
)
class BackgroundError(Base):
"""Important background errors. Serves to:
1. Ensure that important logs are kept around and not lost on rotation/container restarts
2. A trail for high-signal events so that the debugger doesn't need to remember/know every
possible relevant log line.
"""
__tablename__ = "background_error"
id: Mapped[int] = mapped_column(primary_key=True)
message: Mapped[str] = mapped_column(String)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
# option to link the error to a specific CC Pair
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id", ondelete="CASCADE"), nullable=True
)
cc_pair: Mapped["ConnectorCredentialPair | None"] = relationship(
"ConnectorCredentialPair", back_populates="background_errors"
)
"""Tables related to Permission Sync"""

View File

@@ -6,6 +6,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
def get_default_document_index(
@@ -23,14 +24,27 @@ def get_default_document_index(
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
# modify index names for integration tests so that we can run many tests
# using the same Vespa instance w/o having them collide
primary_index_name = search_settings.index_name
if VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY:
primary_index_name = (
f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{primary_index_name}"
)
if secondary_index_name:
secondary_index_name = f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{secondary_index_name}"
# Currently only supporting Vespa
return VespaIndex(
index_name=search_settings.index_name,
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
preserve_existing_indices=bool(
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
),
)

View File

@@ -136,6 +136,7 @@ class VespaIndex(DocumentIndex):
secondary_large_chunks_enabled: bool | None,
multitenant: bool = False,
httpx_client: httpx.Client | None = None,
preserve_existing_indices: bool = False,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
@@ -161,18 +162,18 @@ class VespaIndex(DocumentIndex):
secondary_index_name
] = secondary_large_chunks_enabled
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
)
return None
self.preserve_existing_indices = preserve_existing_indices
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
@classmethod
def create_indices(
cls,
indices: list[tuple[str, int, bool]],
application_endpoint: str = VESPA_APPLICATION_ENDPOINT,
) -> None:
"""
Create indices in Vespa based on the passed in configuration(s).
"""
deploy_url = f"{application_endpoint}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
@@ -185,7 +186,7 @@ class VespaIndex(DocumentIndex):
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_names = [self.index_name, self.secondary_index_name]
schema_names = [index_name for (index_name, _, _) in indices]
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
@@ -193,14 +194,6 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
needs_reindexing = False
try:
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
@@ -221,29 +214,63 @@ class VespaIndex(DocumentIndex):
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
for index_name, index_embedding_dim, needs_reindexing in indices:
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
logger.info(
f"Creating index: {index_name} with embedding "
f"dimension: {index_embedding_dim}. Schema:\n\n {schema}"
)
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
logger.error(f"Failed to create Vespa indices: {response.text}")
raise RuntimeError(
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
)
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if self.multitenant or MULTI_TENANT: # be extra safe here
logger.info(
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None
# Used in IT
# NOTE: this means that we can't switch embedding models
if self.preserve_existing_indices:
logger.info("Preserving existing indices")
return None
kv_store = get_kv_store()
primary_needs_reindexing = False
try:
primary_needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
indices = [
(self.index_name, index_embedding_dim, primary_needs_reindexing),
]
if self.secondary_index_name and secondary_index_embedding_dim:
indices.append(
(self.secondary_index_name, secondary_index_embedding_dim, False)
)
self.create_indices(indices)
@staticmethod
def register_multitenant_indices(
indices: list[str],

View File

@@ -1,6 +1,10 @@
import time
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import ChunkEmbedding
@@ -217,3 +221,49 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
deployment_name=search_settings.deployment_name,
callback=callback,
)
def embed_chunks_with_failure_handling(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
"""Tries to embed all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
"""
# First try to embed all chunks in one batch
try:
return embedder.embed_chunks(chunks=chunks), []
except Exception:
logger.exception("Failed to embed chunk batch. Trying individual docs.")
# wait a couple seconds to let any rate limits or temporary issues resolve
time.sleep(2)
# Try embedding each document's chunks individually
chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list)
for chunk in chunks:
chunks_by_doc[chunk.source_document.id].append(chunk)
embedded_chunks: list[IndexChunk] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_by_doc.items():
try:
doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
embedded_chunks.extend(doc_embedded_chunks)
except Exception as e:
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
failures.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
),
failure_message=str(e),
exception=e,
)
)
return embedded_chunks, failures

View File

@@ -1,23 +1,21 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol
import httpx
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
@@ -29,7 +27,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.models import Document as DBDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
@@ -41,10 +38,12 @@ from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
@@ -67,6 +66,8 @@ class IndexingPipelineResult(BaseModel):
# number of chunks that were inserted into Vespa
total_chunks: int
failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -156,14 +157,10 @@ def index_doc_batch_with_handler(
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
attempt_id: int | None,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> IndexingPipelineResult:
index_pipeline_result = IndexingPipelineResult(
new_docs=0, total_docs=len(document_batch), total_chunks=0
)
try:
index_pipeline_result = index_doc_batch(
chunker=chunker,
@@ -176,47 +173,25 @@ def index_doc_batch_with_handler(
tenant_id=tenant_id,
)
except Exception as e:
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage indicates "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
logger.exception(f"Failed to index document batch: {document_batch}")
index_pipeline_result = IndexingPipelineResult(
new_docs=0,
total_docs=len(document_batch),
total_chunks=0,
failures=[
ConnectorFailure(
failed_document=DocumentFailure(
document_id=document.id,
document_link=(
document.sections[0].link if document.sections else None
),
),
failure_message=str(e),
exception=e,
)
if INDEXING_EXCEPTION_LIMIT == 0:
raise
trace = traceback.format_exc()
create_index_attempt_error(
attempt_id,
batch=index_attempt_metadata.batch_num,
docs=document_batch,
exception_msg=str(e),
exception_traceback=trace,
db_session=db_session,
for document in document_batch
],
)
logger.exception(
f"Indexing batch {index_attempt_metadata.batch_num} failed. msg='{e}' trace='{trace}'"
)
index_attempt_metadata.num_exceptions += 1
if index_attempt_metadata.num_exceptions == INDEXING_EXCEPTION_LIMIT:
logger.warning(
f"Maximum number of exceptions for this index attempt "
f"({INDEXING_EXCEPTION_LIMIT}) has been reached. "
f"The next exception will abort the indexing attempt."
)
elif index_attempt_metadata.num_exceptions > INDEXING_EXCEPTION_LIMIT:
logger.warning(
f"Maximum number of exceptions for this index attempt "
f"({INDEXING_EXCEPTION_LIMIT}) has been exceeded."
)
raise RuntimeError(
f"Maximum exception limit of {INDEXING_EXCEPTION_LIMIT} exceeded."
)
else:
pass
return index_pipeline_result
@@ -376,8 +351,12 @@ def index_doc_batch(
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
db_session.commit()
return IndexingPipelineResult(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
doc_descriptors = [
@@ -390,10 +369,19 @@ def index_doc_batch(
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
logger.debug("Starting chunking")
# NOTE: no special handling for failures here, since the chunker is not
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
)
if chunks
else ([], [])
)
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -459,7 +447,11 @@ def index_doc_batch(
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
insertion_records = document_index.index(
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=access_aware_chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
@@ -519,6 +511,7 @@ def index_doc_batch(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
failures=vector_db_write_failures + embedding_failures,
)
return result
@@ -531,7 +524,6 @@ def build_indexing_pipeline(
db_session: Session,
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
@@ -553,7 +545,6 @@ def build_indexing_pipeline(
embedder=embedder,
document_index=document_index,
ignore_time_skip=ignore_time_skip,
attempt_id=attempt_id,
db_session=db_session,
tenant_id=tenant_id,
)

View File

@@ -57,6 +57,13 @@ class DocAwareChunk(BaseChunk):
"""Used when logging the identity of a chunk"""
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
def get_link(self) -> str | None:
return (
self.source_document.sections[0].link
if self.source_document.sections
else None
)
class IndexChunk(DocAwareChunk):
embeddings: ChunkEmbedding

View File

@@ -0,0 +1,99 @@
import time
from collections import defaultdict
from http import HTTPStatus
import httpx
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import IndexBatchParams
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _log_insufficient_storage_error(e: Exception) -> None:
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage indicates "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
)
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
vector DB interface assumes that all chunks for a single document are present.
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=chunks,
index_batch_params=index_batch_params,
)
),
[],
)
except Exception as e:
logger.exception(
"Failed to write chunk batch to vector db. Trying individual docs."
)
# give some specific logging on this common failure case.
_log_insufficient_storage_error(e)
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_for_docs.items():
try:
insertion_records.extend(
document_index.index(
chunks=chunks_for_doc,
index_batch_params=index_batch_params,
)
)
except Exception as e:
logger.exception(
f"Failed to write document chunks for '{doc_id}' to vector db"
)
# give some specific logging on this common failure case.
_log_insufficient_storage_error(e)
failures.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
),
failure_message=str(e),
exception=e,
)
)
return insertion_records, failures

Some files were not shown because too many files have changed in this diff Show More