Compare commits

27 Commits

Author SHA1 Message Date
Weves
4d1d5bdcfe tweak worker count 2025-02-16 14:51:10 -08:00
Weves
0d468b49a1 Final fixes 2025-02-16 14:50:51 -08:00
Weves
67b87ced39 fix paths 2025-02-16 14:23:51 -08:00
Weves
8b4e4a6c80 Fix paths 2025-02-16 14:22:12 -08:00
Weves
e26bcf5a05 Misc fixes 2025-02-16 14:02:16 -08:00
Weves
435959cf90 test 2025-02-16 14:02:16 -08:00
Weves
fcbe305dc0 test 2025-02-16 14:02:16 -08:00
Weves
6f13d44564 move 2025-02-16 14:02:16 -08:00
Weves
c1810a35cd Fix redis port 2025-02-16 14:02:16 -08:00
Weves
4003e7346a test 2025-02-16 14:02:16 -08:00
Weves
8057f1eb0d Add logging 2025-02-16 14:02:16 -08:00
Weves
7eebd3cff1 test not removing files 2025-02-16 14:02:16 -08:00
Weves
bac2aeb8b7 Fix 2025-02-16 14:02:16 -08:00
Weves
9831697acc Make migrations work 2025-02-16 14:02:15 -08:00
Weves
5da766dd3b testing 2025-02-16 14:01:50 -08:00
Weves
180608694a improvements 2025-02-16 14:01:49 -08:00
Weves
96b92edfdb Parallelize IT
Parallelization

Full draft of first pass

Adjust test name

test

test

Fix

Update cmd

test

Fix

test

Test with all tests

Resource bump + limit num parallel runs

Add retries
2025-02-16 14:01:12 -08:00
pablonyx
ec0e55fd39 Seeding count issue (#4009)
* k

* k

* quick nit

* nit
2025-02-16 20:49:25 +00:00
pablonyx
e441c899af Playwright + Chromatic update (#4015) 2025-02-16 13:03:45 -08:00
Chris Weaver
f1fc8ac19b Connector checkpointing (#3876)
* wip checkpointing/continue on failure

more stuff for checkpointing

Basic implementation

FE stuff

More checkpointing/failure handling

rebase

rebase

initial scaffolding for IT

IT to test checkpointing

Cleanup

cleanup

Fix it

Rebase

Add todo

Fix actions IT

Test more

Pagination + fixes + cleanup

Fix IT networking

fix it

* rebase

* Address misc comments

* Address comments

* Remove unused router

* rebase

* Fix mypy

* Fixes

* fix it

* Fix tests

* Add drop index

* Add retries

* reset lock timeout

* Try hard drop of schema

* Add timeout/retries to downgrade

* rebase

* test

* test

* test

* Close all connections

* test closing idle only

* Fix it

* fix

* try using null pool

* Test

* fix

* rebase

* log

* Fix

* apply null pool

* Fix other test

* Fix quality checks

* Test not using the fixture

* Fix ordering

* fix test

* Change pooling behavior
2025-02-16 02:34:39 +00:00
Weves
bc087fc20e Fix ruff 2025-02-15 16:35:15 -08:00
Yuhong Sun
ab8081c36b k 2025-02-15 13:42:43 -08:00
Adam Siemiginowski
f371efc916 Fix Zulip connector schema + links and enable temporal metadata (#4005) 2025-02-15 11:49:41 -08:00
pablonyx
7fd5d31dbe Minor background process log cleanup (#4010) 2025-02-15 11:03:10 -08:00
rkuo-danswer
2829e6715e Feature/propagate exceptions (#3974)
* better propagation of exceptions up the stack

* remove debug testing

* refactor the watchdog more to emit data consistently at the end of the function

* enumerate a lot more terminal statuses

* handle more codes

* improve logging

* handle "-9"

* single line exception logging

* typo/grammar

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-15 04:53:01 +00:00
Weves
bc7b4ec396 Fix typing for metadata 2025-02-14 18:19:37 -08:00
pablonyx
697f8bc1c6 Reduce background errors (#4004) 2025-02-14 17:35:26 -08:00
93 changed files with 5426 additions and 1552 deletions

View File

@@ -0,0 +1,153 @@
name: Run Integration Tests v3
concurrency:
group: Run-Integration-Tests-Parallel-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[runs-on, runner=32cpu-linux-x64, ram=64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Run Standard Integration Tests
run: |
# Print a message indicating that tests are starting
echo "Running integration tests..."
# Create a directory for test logs that will be mounted into the container
mkdir -p ${{ github.workspace }}/test_logs
chmod 777 ${{ github.workspace }}/test_logs
# Run the integration tests in a Docker container
# Mount the Docker socket to allow Docker-in-Docker (DinD)
# Mount the test_logs directory to capture logs
# Use host network for easier communication with other services
docker run \
-v /var/run/docker.sock:/var/run/docker.sock \
-v ${{ github.workspace }}/test_logs:/tmp \
--network host \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
danswer/danswer-integration:test \
python /app/tests/integration/run.py
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Collect log files
if: success() || failure()
run: |
# Create a directory for logs
mkdir -p ${{ github.workspace }}/logs
mkdir -p ${{ github.workspace }}/logs/shared_services
# Copy all relevant log files from the mounted directory
cp ${{ github.workspace }}/test_logs/api_server_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/background_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/shared_model_server.txt ${{ github.workspace }}/logs/ || true
# Collect logs from shared services (Docker containers)
# Note: using a wildcard for the UUID part of the stack name
docker ps -a --filter "name=base-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# Also collect Redis container logs
docker ps -a --filter "name=redis-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# List collected logs
echo "Collected log files:"
ls -l ${{ github.workspace }}/logs/
echo "Collected shared services logs:"
ls -l ${{ github.workspace }}/logs/shared_services/
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: integration-test-logs
path: |
${{ github.workspace }}/logs/
${{ github.workspace }}/logs/shared_services/
retention-days: 5
# save before stopping the containers so the logs can be captured
# - name: Save Docker logs
# if: success() || failure()
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
# mv docker-compose.log ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# - name: Upload logs
# if: success() || failure()
# uses: actions/upload-artifact@v4
# with:
# name: docker-logs
# path: ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v
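
The workflow above delegates test orchestration to /app/tests/integration/run.py inside the image, with the host Docker socket mounted so the runner can start sibling service stacks itself. run.py is not shown in this compare view; the sketch below is only a hypothetical illustration of how such a runner could fan suites out across a bounded number of workers (the suite list and worker count are assumptions, echoing the "Resource bump + limit num parallel runs" commit).

```python
# Hypothetical sketch of a parallel integration-test runner. run.py itself is
# not part of this diff; the suite list and worker count are assumptions.
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor

TEST_SUITES = [
    "/app/tests/integration/tests",
    "/app/tests/integration/connector_job_tests",
]
MAX_PARALLEL_RUNS = 4  # bounded, per "Resource bump + limit num parallel runs"


def run_suite(test_dir: str) -> int:
    # Each suite gets its own pytest process; logs land in /tmp, which the
    # workflow mounts as ${GITHUB_WORKSPACE}/test_logs.
    return subprocess.run(["pytest", "-q", test_dir]).returncode


def main() -> int:
    with ThreadPoolExecutor(max_workers=MAX_PARALLEL_RUNS) as pool:
        exit_codes = list(pool.map(run_suite, TEST_SUITES))
    # A non-zero exit here is what makes the "Check test results" step fail.
    return 0 if all(code == 0 for code in exit_codes) else 1


if __name__ == "__main__":
    sys.exit(main())
```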

View File

@@ -5,10 +5,10 @@ concurrency:
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
# pull_request:
# branches:
# - main
# - "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -99,7 +99,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
@@ -108,12 +108,13 @@ jobs:
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -143,24 +144,27 @@ jobs:
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -190,15 +194,24 @@ jobs:
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -208,6 +221,8 @@ jobs:
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
@@ -229,13 +244,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -249,4 +264,4 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
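
Both test jobs now export POSTGRES_POOL_PRE_PING and POSTGRES_USE_NULL_POOL, per the "reduce flakiness due to dropped connections" notes. A minimal sketch of what those flags usually translate to on the SQLAlchemy side is below; the env-var plumbing is an assumption for illustration, since only the flags themselves appear in this diff.

```python
# Minimal sketch, assuming the backend's engine factory reads these flags
# from the environment (the plumbing is not shown in this diff).
import os

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

use_null_pool = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
pool_pre_ping = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"

engine_kwargs: dict = {
    # pre-ping issues a cheap liveness check before each checkout and
    # transparently reconnects if the pooled connection has gone stale.
    "pool_pre_ping": pool_pre_ping,
}
if use_null_pool:
    # NullPool opens and closes a connection per checkout, so a connection
    # silently dropped by the server is never handed back out.
    engine_kwargs["poolclass"] = NullPool

engine = create_engine(
    "postgresql+psycopg2://postgres:password@relational_db:5432/postgres",
    **engine_kwargs,
)
```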

View File

@@ -1,6 +1,6 @@
name: Run Chromatic Tests
name: Run Playwright Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
@@ -198,43 +198,47 @@ jobs:
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
needs: playwright-tests
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
# chromatic-tests:
# name: Chromatic Tests
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Install node dependencies
working-directory: ./web
run: npm ci
# - name: Setup node
# uses: actions/setup-node@v4
# with:
# node-version: 22
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results
# - name: Download Playwright test results
# uses: actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -205,7 +205,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
],
"presentation": {
"group": "2",

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.db.engine import SYNC_DB_API, get_iam_auth_token
from onyx.configs.app_configs import POSTGRES_DB, USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
@@ -13,12 +13,11 @@ from sqlalchemy import text
from sqlalchemy.engine.base import Connection
import os
import ssl
import asyncio
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import create_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
@@ -133,17 +132,32 @@ def provide_iam_token_for_alembic(
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
def run_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
# Get any environment variables passed through alembic config
env_vars = context.config.attributes.get("env_vars", {})
# Use env vars if provided, otherwise fall back to defaults
postgres_host = env_vars.get("POSTGRES_HOST", POSTGRES_HOST)
postgres_port = env_vars.get("POSTGRES_PORT", POSTGRES_PORT)
postgres_user = env_vars.get("POSTGRES_USER", POSTGRES_USER)
postgres_db = env_vars.get("POSTGRES_DB", POSTGRES_DB)
engine = create_engine(
build_connection_string(
db=postgres_db,
user=postgres_user,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
),
poolclass=pool.NullPool,
)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
@@ -152,31 +166,26 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
if schema is None:
continue
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema_name, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
engine.dispose()
def run_migrations_offline() -> None:
@@ -184,18 +193,18 @@ def run_migrations_offline() -> None:
url = build_connection_string()
if upgrade_all_tenants:
engine = create_async_engine(url)
engine = create_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
@@ -230,7 +239,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
asyncio.run(run_async_migrations())
run_migrations()
if context.is_offline_mode():
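
The rewritten env.py reads connection overrides out of context.config.attributes["env_vars"] and builds a synchronous engine, so callers can run migrations against an arbitrary Postgres instance without mutating process-wide environment variables. A hedged caller-side sketch is below; the integration-test harness that actually does this is not part of this file's diff.

```python
# Hypothetical caller-side sketch: alembic's Config.attributes dict is the
# documented channel for passing values into env.py, which the rewritten
# env.py reads under the "env_vars" key.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")
alembic_cfg.attributes["env_vars"] = {
    "POSTGRES_HOST": "localhost",
    "POSTGRES_PORT": "5433",  # e.g. a per-test Postgres container
    "POSTGRES_USER": "postgres",
    "POSTGRES_DB": "postgres",
}
command.upgrade(alembic_cfg, "head")
```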

View File

@@ -0,0 +1,124 @@
"""Add checkpointing/failure handling
Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt",
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"ix_index_attempt_cc_pair_settings_poll",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
"status",
sa.text("time_updated DESC"),
],
)
# Drop the old IndexAttemptError table
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# Create the new version of the table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("document_link", sa.String(), nullable=True),
sa.Column("entity_id", sa.String(), nullable=True),
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("failure_message", sa.Text(), nullable=False),
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
)
def downgrade() -> None:
op.execute("SET lock_timeout = '5s'")
# try a few times to drop the table, this has been observed to fail due to other locks
# blocking the drop
NUM_TRIES = 10
for i in range(NUM_TRIES):
try:
op.drop_table("index_attempt_errors")
break
except Exception as e:
if i == NUM_TRIES - 1:
raise e
print(f"Error dropping table: {e}. Retrying...")
op.execute("SET lock_timeout = DEFAULT")
# Recreate the old IndexAttemptError table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
)
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
op.drop_column("index_attempt", "checkpoint_pointer")
op.drop_column("index_attempt", "poll_range_start")
op.drop_column("index_attempt", "poll_range_end")

View File

@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

View File

@@ -94,7 +94,6 @@ from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
@@ -108,6 +107,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True

View File

@@ -36,6 +36,15 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-checkpoint-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,

View File

@@ -1,9 +1,10 @@
import multiprocessing
import os
import sys
import time
import traceback
from datetime import datetime
from datetime import timezone
from enum import Enum
from http import HTTPStatus
from time import sleep
from typing import Any
@@ -15,6 +16,7 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -26,7 +28,13 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attem
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
from onyx.background.indexing.checkpointing_utils import (
get_index_attempts_with_old_checkpoints,
)
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
@@ -34,6 +42,7 @@ from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
@@ -70,6 +79,123 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None
class ConnectorIndexingContext(BaseModel):
tenant_id: str | None
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class ConnectorIndexingLogBuilder:
def __init__(self, ctx: ConnectorIndexingContext):
self.ctx = ctx
def build(self, msg: str, **kwargs: Any) -> str:
msg_final = (
f"{msg}: "
f"tenant_id={self.ctx.tenant_id} "
f"attempt={self.ctx.index_attempt_id} "
f"cc_pair={self.ctx.cc_pair_id} "
f"search_settings={self.ctx.search_settings_id}"
)
# Append extra keyword arguments in logfmt style
if kwargs:
extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items())
msg_final = f"{msg_final} {extra_logfmt}"
return msg_final
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -496,7 +622,6 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
# 20 is the documented default for httpx max_keepalive_connections
@@ -513,19 +638,21 @@ def connector_indexing_task(
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
raise RuntimeError(
raise SimpleJobException(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
)
if redis_connector.stop.fenced:
raise RuntimeError(
raise SimpleJobException(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
# this wait is needed to avoid a race condition where
@@ -534,19 +661,24 @@ def connector_indexing_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
raise SimpleJobException(
f"connector_indexing_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
f"fence={redis_connector.permissions.fence_key}",
code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
)
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
raise SimpleJobException(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
raise SimpleJobException(
"connector_indexing_task: payload invalid or not found",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
@@ -556,10 +688,11 @@ def connector_indexing_task(
continue
if payload.index_attempt_id != index_attempt_id:
raise ValueError(
raise SimpleJobException(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}"
f"payload_index_attempt={payload.index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
logger.info(
@@ -583,7 +716,14 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return None
raise SimpleJobException(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
)
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
@@ -592,10 +732,10 @@ def connector_indexing_task(
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
@@ -603,16 +743,21 @@ def connector_indexing_task(
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
if not cc_pair.connector:
raise ValueError(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
raise SimpleJobException(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
raise SimpleJobException(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
# define a callback class
@@ -650,20 +795,6 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt_found:
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:
@@ -678,41 +809,49 @@ def connector_indexing_task(
return n_final_progress
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if job.process:
result.exit_code = job.process.exitcode
# There is a cloud related bug outside of our code
# where spawned tasks return with an exit code of 1.
# Unfortunately, exceptions also return with an exit code of 1,
# so just raising an exception isn't informative
# Exiting with 255 makes it possible to distinguish between normal exits
# and exceptions.
sys.exit(255)
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
return result
@@ -730,12 +869,32 @@ def connector_indexing_proxy_task(
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
"""
start = time.monotonic()
result = SimpleJobResult()
ctx = ConnectorIndexingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"mp_start_method={multiprocessing.get_start_method()}"
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
)
if not self.request.id:
@@ -744,7 +903,7 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()
job = client.submit(
connector_indexing_task_wrapper,
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -754,139 +913,223 @@ def connector_indexing_proxy_task(
)
if not job:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
return
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(5)
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
exit_code: int | None
try:
if job.status == "error":
ignore_exitcode = False
exit_code = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates successful completion
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if not ignore_exitcode:
raise RuntimeError("Spawned task exceptioned.")
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
except Exception:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
raise
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
while True:
sleep(5)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
)
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
result.exception_str = traceback.format_exc()
job.cancel()
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
continue
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
redis_connector_index.set_watchdog(False)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)
redis_connector_index.set_watchdog(False)
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock.acquire(blocking=False):
return None
try:
locked = True
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
for attempt in old_attempts:
task_logger.info(
f"Cleaning up checkpoint for index attempt {attempt.id}"
)
cleanup_checkpoint_task.apply_async(
kwargs={
"index_attempt_id": attempt.id,
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
)
except Exception:
task_logger.exception("Unexpected exception during checkpoint cleanup")
return None
finally:
if locked:
if lock.owned():
lock.release()
else:
task_logger.error(
"check_for_checkpoint_cleanup - Lock not owned on completion: "
f"tenant={tenant_id}"
)
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,
)
def cleanup_checkpoint_task(
self: Task, *, index_attempt_id: int, tenant_id: str | None
) -> None:
"""Clean up a checkpoint for a given index attempt"""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
cleanup_checkpoint(db_session, index_attempt_id)
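
The watchdog and the spawned connector process communicate a terminal status purely through the child's exit code: the task raises SimpleJobException(code=...), the job client exits the child with that code, and the watchdog maps job.process.exitcode back through IndexingWatchdogTerminalStatus.from_code. A stripped-down sketch of that round trip, with the Celery/Redis machinery removed, is below.

```python
# Standalone sketch of the exit-code round trip. The real code routes this
# through SimpleJobClient/SimpleJob and the full status enum above; the two
# codes here are just examples.
import multiprocessing as mp
import sys

BLOCKED_BY_DELETION = 248
CONNECTOR_EXCEPTIONED = 255

_CODE_TO_STATUS = {
    248: "blocked_by_deletion",
    255: "connector_exceptioned",
}


def child() -> None:
    # Stands in for connector_indexing_task raising
    # SimpleJobException(code=...), which _initializer turns into sys.exit(code).
    sys.exit(BLOCKED_BY_DELETION)


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # spawn rather than fork, per the proxy task docstring
    proc = ctx.Process(target=child)
    proc.start()
    proc.join()
    status = _CODE_TO_STATUS.get(proc.exitcode, "undefined")
    print(f"watchdog saw exit_code={proc.exitcode} -> status={status}")
```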

View File

@@ -240,7 +240,8 @@ def validate_indexing_fence(
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return

View File

@@ -105,6 +105,7 @@ def document_by_cc_pair_cleanup_task(
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],

View File

@@ -1,80 +0,0 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime
from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
def _2010_dt() -> datetime.datetime:
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
def _2020_dt() -> datetime.datetime:
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
def _default_end_time(
last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
"""If year is before 2010, go to the beginning of 2010.
If year is 2010-2020, go in 5 year increments.
If year > 2020, then go in 180 day increments.
For connectors that don't support a `filter_by` and instead rely on `sort_by`
for polling, then this will cause a massive duplication of fetches. For these
connectors, you may want to override this function to return a more reasonable
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
last_successful_run = (
datetime_to_utc(last_successful_run) if last_successful_run else None
)
if last_successful_run is None or last_successful_run < _2010_dt():
return _2010_dt()
if last_successful_run < _2020_dt():
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
return last_successful_run + datetime.timedelta(days=180)
def find_end_time_for_indexing_attempt(
last_successful_run: datetime.datetime | None,
# source_type can be used to override the default for certain connectors, currently unused
source_type: DocumentSource,
) -> datetime.datetime | None:
"""Is the current time unless the connector is run over a large period, in which case it is
split up into large time segments that become smaller as it approaches the present
"""
# NOTE: source_type can be used to override the default for certain connectors
end_of_window = _default_end_time(last_successful_run)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if end_of_window < now:
return end_of_window
# None signals that we should index up to current time
return None
def get_time_windows_for_index_attempt(
last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
start_of_window: datetime.datetime | None = last_successful_run
while start_of_window:
end_of_window = find_end_time_for_indexing_attempt(
last_successful_run=start_of_window, source_type=source_type
)
time_windows.append(
(
start_of_window,
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
)
)
start_of_window = end_of_window
return time_windows

View File

@@ -0,0 +1,200 @@
from datetime import datetime
from datetime import timedelta
from io import BytesIO
from sqlalchemy import and_
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof
logger = setup_logger()
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100
def _build_checkpoint_pointer(index_attempt_id: int) -> str:
return f"checkpoint_{index_attempt_id}.json"
def save_checkpoint(
db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
) -> str:
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=checkpoint_pointer,
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
file_origin=FileOrigin.INDEXING_CHECKPOINT,
file_type="application/json",
)
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
index_attempt.checkpoint_pointer = checkpoint_pointer
db_session.add(index_attempt)
db_session.commit()
return checkpoint_pointer
def load_checkpoint(
db_session: Session, index_attempt_id: int
) -> ConnectorCheckpoint | None:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:
return None
def get_latest_valid_checkpoint(
db_session: Session,
cc_pair_id: int,
search_settings_id: int,
window_start: datetime,
window_end: datetime,
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
db_session=db_session,
limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
)
checkpoint_candidates = [
candidate
for candidate in checkpoint_candidates
if (
candidate.poll_range_start == window_start
and candidate.poll_range_end == window_end
and candidate.status == IndexingStatus.FAILED
and candidate.checkpoint_pointer is not None
# we want to make sure that the checkpoint is actually useful
# if it's only gone through a few docs, it's probably not worth
# using. This also avoids weird cases where a connector is basically
# non-functional but still "makes progress" by slowly moving the
# checkpoint forward run after run
and candidate.total_docs_indexed
and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
)
]
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# for now, capped at 10
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
logger.warning(
f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
latest_valid_checkpoint_candidate = (
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
)
except Exception:
logger.exception(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}."
)
previous_checkpoint = None
if previous_checkpoint is not None:
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
save_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
checkpoint=previous_checkpoint,
)
checkpoint = previous_checkpoint
return checkpoint
def get_index_attempts_with_old_checkpoints(
db_session: Session, days_to_keep: int = 7
) -> list[IndexAttempt]:
"""Get all index attempts with checkpoints older than the specified number of days.
Args:
db_session: The database session
days_to_keep: Number of days to keep checkpoints for (default: 7)
Returns:
List of index attempts whose checkpoints are older than the cutoff
"""
cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)
# Find all index attempts with checkpoints older than cutoff_date
old_attempts = (
db_session.query(IndexAttempt)
.filter(
and_(
IndexAttempt.checkpoint_pointer.isnot(None),
IndexAttempt.time_created < cutoff_date,
)
)
.all()
)
return old_attempts
def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
"""Clean up a checkpoint for a given index attempt"""
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None
db_session.add(index_attempt)
db_session.commit()
return None
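# Illustrative sketch: a cleanup pass might pair the two helpers above - find attempts
# whose checkpoints have outlived the retention window, then delete each checkpoint
# file and clear its pointer. The wrapper function itself is an assumption.
def _example_cleanup_old_checkpoints(db_session: Session) -> int:
    old_attempts = get_index_attempts_with_old_checkpoints(db_session)
    for attempt in old_attempts:
        cleanup_checkpoint(db_session, attempt.id)
    return len(old_attempts)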
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
)
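# Illustrative sketch: the kind of periodic guard _run_indexing applies, checking the
# checkpoint size every N batches so oversized checkpoints fail fast. The interval
# default here is just an example value.
def _example_periodic_size_check(
    checkpoint: ConnectorCheckpoint, batch_num: int, interval: int = 100
) -> None:
    if batch_num % interval == 0:
        # raises ValueError if checkpoint_content exceeds the 200MB limit
        check_checkpoint_size(checkpoint)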

View File

@@ -5,6 +5,8 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
import sys
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
@@ -18,6 +20,16 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class SimpleJobException(Exception):
"""lets us raise an exception that will return a specific error code"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
code: int | None = kwargs.pop("code", None)
self.code = code
super().__init__(*args, **kwargs)
JobStatusType = (
Literal["error"]
| Literal["finished"]
@@ -28,7 +40,10 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
@@ -52,13 +67,29 @@ def _initializer(
)
# Proceed with executing the target function
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except SimpleJobException as e:
logger.exception("SimpleJob raised a SimpleJobException")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(e.code) # use the given exit code
except Exception:
logger.exception("SimpleJob raised an exception")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
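# Illustrative sketch: a function submitted through SimpleJobClient can raise
# SimpleJobException with a specific code; the handler above turns that into the
# child process's exit code and pushes the traceback onto the queue. The function
# name and code value are arbitrary examples.
def _example_worker(should_fail: bool) -> None:
    if should_fail:
        raise SimpleJobException("worker hit a terminal condition", code=75)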
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
) -> None:
_initializer(func, args, kwargs)
_initializer(func, queue, args, kwargs)
@dataclass
@@ -67,6 +98,8 @@ class SimpleJob:
id: int
process: Optional["SpawnProcess"] = None
queue: Optional[mp.Queue] = None
_exception: Optional[str] = None
def cancel(self) -> bool:
return self.release()
@@ -100,9 +133,15 @@ class SimpleJob:
def exception(self) -> str:
"""Needed to match the Dask API, but not implemented since we don't currently
have a way to get back the exception information from the child process."""
return (
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
)
"""Retrieve exception from the multiprocessing queue if available."""
if self._exception is None and self.queue and not self.queue.empty():
self._exception = self.queue.get() # Get exception from queue
if self._exception:
return self._exception
return f"Job with ID '{self.id}' did not report an exception."
class SimpleJobClient:
@@ -137,8 +176,11 @@ class SimpleJobClient:
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
queue = ctx.Queue()
process = ctx.Process(
target=_run_in_process, args=(func, queue, args), daemon=True
)
job = SimpleJob(id=job_id, process=process, queue=queue)
process.start()
self.jobs[job_id] = job
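# Illustrative sketch: after the spawned process exits, the parent can surface the
# child's traceback (if any) via the queue-backed exception(). The join-then-check
# pattern here is an assumption for the example.
def _example_wait_and_report(job: SimpleJob) -> None:
    if job.process is None:
        return
    job.process.join()
    if job.process.exitcode not in (None, 0):
        # exception() drains the mp.Queue that _initializer populated on failure
        logger.error(job.exception())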

View File

@@ -0,0 +1,87 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class MemoryTracer:
def __init__(self, interval: int = 0, num_print_entries: int = 5):
self.interval = interval
self.num_print_entries = num_print_entries
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
self.counter = 0
def start(self) -> None:
"""Start the memory tracer if interval is greater than 0."""
if self.interval > 0:
logger.debug(f"Memory tracer starting: interval={self.interval}")
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
self._take_snapshot()
def stop(self) -> None:
"""Stop the memory tracer if it's running."""
if self.interval > 0:
self.log_final_diff()
tracemalloc.stop()
logger.debug("Memory tracer stopped.")
def _take_snapshot(self) -> None:
"""Take a snapshot and update internal snapshot states."""
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__),
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def _log_diff(
self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
) -> None:
"""Log the memory difference between two snapshots."""
stats = current.compare_to(previous, "traceback")
for s in stats[: self.num_print_entries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def increment_and_maybe_trace(self) -> None:
"""Increment counter and perform trace if interval is hit."""
if self.interval <= 0:
return
self.counter += 1
if self.counter % self.interval == 0:
logger.debug(
f"Running trace comparison for batch {self.counter}. interval={self.interval}"
)
self._take_snapshot()
if self.snapshot and self.snapshot_prev:
self._log_diff(self.snapshot, self.snapshot_prev)
def log_final_diff(self) -> None:
"""Log the final memory diff between start and end of indexing."""
if self.interval <= 0:
return
logger.debug(
f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
)
self._take_snapshot()
if self.snapshot and self.snapshot_first:
self._log_diff(self.snapshot, self.snapshot_first)
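# Illustrative sketch: the MemoryTracer lifecycle as used by the indexing loop. The
# interval value is an example; with interval=0 every call is a no-op.
def _example_trace_batches(num_batches: int, interval: int = 50) -> None:
    tracer = MemoryTracer(interval=interval)
    tracer.start()
    for _ in range(num_batches):
        # logs a snapshot diff every `interval` batches
        tracer.increment_and_maybe_trace()
    # logs the start-vs-end diff, then stops tracemalloc
    tracer.stop()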

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.db.models import IndexAttemptError
class IndexAttemptErrorPydantic(BaseModel):
id: int
connector_credential_pair_id: int
document_id: str | None
document_link: str | None
entity_id: str | None
failed_time_range_start: datetime | None
failed_time_range_end: datetime | None
failure_message: str
is_resolved: bool = False
time_created: datetime
index_attempt_id: int
@classmethod
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
return cls(
id=model.id,
connector_credential_pair_id=model.connector_credential_pair_id,
document_id=model.document_id,
document_link=model.document_link,
entity_id=model.entity_id,
failed_time_range_start=model.failed_time_range_start,
failed_time_range_end=model.failed_time_range_end,
failure_message=model.failure_message,
is_resolved=model.is_resolved,
time_created=model.time_created,
index_attempt_id=model.index_attempt_id,
)

View File

@@ -1,5 +1,6 @@
import time
import traceback
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -7,8 +8,11 @@ from datetime import timezone
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
from onyx.background.indexing.tracer import OnyxTracer
from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
from onyx.background.indexing.checkpointing_utils import save_checkpoint
from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
@@ -17,6 +21,8 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -24,15 +30,18 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
@@ -53,6 +62,7 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
@@ -100,7 +110,9 @@ def _get_connector_runner(
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
connector=runnable_connector,
batch_size=batch_size,
time_range=(start_time, end_time),
)
@@ -159,6 +171,66 @@ class RunIndexingContext(BaseModel):
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
Raises a RuntimeError if any conditions are not met.
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
def _check_failure_threshold(
total_failures: int,
document_count: int,
batch_num: int,
last_failure: ConnectorFailure | None,
) -> None:
"""Check if we've hit the failure threshold and raise an appropriate exception if so.
We consider the threshold hit if:
1. We have more than 3 failures AND
2. Failures account for more than 10% of processed documents
"""
failure_ratio = total_failures / (document_count or 1)
FAILURE_THRESHOLD = 3
FAILURE_RATIO_THRESHOLD = 0.1
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
logger.error(
f"Connector run failed with '{total_failures}' errors "
f"after '{batch_num}' batches."
)
if last_failure and last_failure.exception:
raise last_failure.exception from last_failure.exception
raise RuntimeError(
f"Connector run encountered too many errors, aborting. "
f"Last error: {last_failure}"
)
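# Illustrative sketch: the threshold only trips when both conditions hold. For example,
# 5 failures across 40 processed documents gives a ratio of 0.125 (> 0.1) with more
# than 3 failures, so the run aborts; 5 failures across 500 documents gives 0.01, so
# indexing continues. The values below are arbitrary examples.
def _example_threshold_check() -> None:
    _check_failure_threshold(
        total_failures=5, document_count=500, batch_num=12, last_failure=None
    )  # does not raise: the failure ratio is only 1%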
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -169,11 +241,8 @@ def _run_indexing(
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
start_time = time.monotonic() # just used for logging
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
@@ -221,6 +290,46 @@ def _run_indexing(
db_session=db_session_temp,
)
)
if last_successful_index_time > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt_start.search_settings_id,
db_session=db_session_temp,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# add start/end now that they have been set
index_attempt_start.poll_range_start = window_start
index_attempt_start.poll_range_end = window_end
db_session_temp.add(index_attempt_start)
db_session_temp.commit()
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
@@ -234,7 +343,6 @@ def _run_indexing(
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt_id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
@@ -246,63 +354,73 @@ def _run_indexing(
callback=callback,
)
tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
tracer.start()
tracer.snap()
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt_md = IndexAttemptMetadata(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
total_failures = 0
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
run_end_dt = None
tracer_counter: int
try:
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
),
source_type=db_connector.source,
)
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_loop_start = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop_start:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
connector_runner = _get_connector_runner(
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
attempt=index_attempt_loop_start,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
db_session=db_session_temp,
)
doc_id_to_unresolved_errors: dict[
str, list[IndexAttemptError]
] = defaultdict(list)
for error in unresolved_errors:
if error.document_id:
doc_id_to_unresolved_errors[error.document_id].append(error)
entity_based_unresolved_errors = [
error for error in unresolved_errors if error.entity_id
]
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
@@ -313,41 +431,37 @@ def _run_indexing(
# TODO: should we move this into the above callback instead?
with get_session_with_tenant(tenant_id) as db_session_temp:
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
(
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_tenant(tenant_id) as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
# if it's deleting, we don't care if this is a secondary index
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(
db_session_temp, index_attempt_id
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
if not index_attempt_loop:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
# save the new checkpoint (if one is provided)
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing logic, so if no batch we can just continue
if document_batch is None:
continue
batch_description = []
doc_batch_cleaned = strip_null_characters(doc_batch)
doc_batch_cleaned = strip_null_characters(document_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
@@ -377,15 +491,51 @@ def _run_indexing(
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# resolve errors for documents that were successfully indexed
failed_document_ids = [
failure.failed_document.document_id
for failure in index_pipeline_result.failures
if failure.failed_document
]
successful_document_ids = [
document.id
for document in document_batch
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_tenant(tenant_id) as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
)
for error in doc_id_to_unresolved_errors[document_id]:
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_tenant(tenant_id) as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures,
document_count,
batch_num,
index_pipeline_result.failures[-1],
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
# so we need either to commit() or to use a new session
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
@@ -397,126 +547,77 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
memory_tracer.increment_and_maybe_trace()
run_end_dt = window_end
if ctx.is_primary:
with get_session_with_tenant(tenant_id) as db_session_temp:
# make sure the checkpoints aren't getting too large at some regular interval
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_tenant(tenant_id) as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or (
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
)
or (
index_attempt_loop is not None
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
)
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason="All batches exceptioned.",
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
memory_tracer.stop()
raise e
else:
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
memory_tracer.stop()
elapsed_time = time.monotonic() - start_time
with get_session_with_tenant(tenant_id) as db_session_temp:
if index_attempt_md.num_exceptions == 0:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
if total_failures == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
@@ -535,7 +636,7 @@ def _run_indexing(
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"failures={total_failures} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
@@ -547,7 +648,7 @@ def _run_indexing(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=run_end_dt,
run_dt=window_end,
)
@@ -558,46 +659,43 @@ def run_indexing_entrypoint(
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
try:
if is_ee:
global_version.set_ee()
"""Don't swallow exceptions here ... propagate them up."""
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
credential_id = attempt.connector_credential_pair.credential_id
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
credential_id = attempt.connector_credential_pair.credential_id
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)

View File

@@ -1,77 +0,0 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class OnyxTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -169,6 +169,11 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -621,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes

View File

@@ -165,6 +165,9 @@ class DocumentSource(str, Enum):
EGNYTE = "egnyte"
AIRTABLE = "airtable"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -243,6 +246,7 @@ class FileOrigin(str, Enum):
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
OTHER = "other"
@@ -274,6 +278,7 @@ class OnyxCeleryQueues:
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
LLM_MODEL_UPDATE = "llm_model_update"
CHECKPOINT_CLEANUP = "checkpoint_cleanup"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
@@ -293,6 +298,7 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
@@ -368,6 +374,10 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
CLEANUP_CHECKPOINT = "cleanup_checkpoint"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"

View File

@@ -0,0 +1,6 @@
import os
SKIP_CONNECTION_POOL_WARM_UP = (
os.environ.get("SKIP_CONNECTION_POOL_WARM_UP", "").lower() == "true"
)

View File

@@ -245,7 +245,7 @@ class AirtableConnector(LoadConnector):
return [(" ".join(combined) if combined else str(field_info), default_link)]
if isinstance(field_info, list):
return [(item, default_link) for item in field_info]
return [(str(item), default_link) for item in field_info]
return [(str(field_info), default_link)]
@@ -268,7 +268,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
) -> tuple[list[Section], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -342,7 +342,7 @@ class AirtableConnector(LoadConnector):
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
primary_field_value = (

View File

@@ -1,11 +1,16 @@
import sys
import time
from collections.abc import Generator
from datetime import datetime
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger
@@ -15,48 +20,139 @@ logger = setup_logger()
TimeRange = tuple[datetime, datetime]
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: ConnectorCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput,
) -> Generator[
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput,
) -> CheckpointOutput:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
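# Illustrative sketch: how a consumer (e.g. ConnectorRunner) might drain the wrapper.
# Each item is a (document, failure, next_checkpoint) triple with exactly one slot
# populated, and the checkpoint triple is guaranteed to come last.
def _example_drain_wrapper(generator: CheckpointOutput) -> ConnectorCheckpoint:
    final_checkpoint: ConnectorCheckpoint | None = None
    for document, failure, next_checkpoint in CheckpointOutputWrapper()(generator):
        if document is not None:
            logger.debug(f"got document {document.id}")
        if failure is not None:
            logger.warning(f"got failure: {failure.failure_message}")
        if next_checkpoint is not None:
            final_checkpoint = next_checkpoint
    # the wrapper raises if the connector fails to return a checkpoint, so this is
    # always non-None once the generator is exhausted
    assert final_checkpoint is not None
    return final_checkpoint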
class ConnectorRunner:
"""
Handles:
- Batching
- Additional exception logging
- Combining different connector types to a single interface
"""
def __init__(
self,
connector: BaseConnector,
batch_size: int,
time_range: TimeRange | None = None,
fail_loudly: bool = False,
):
self.connector = connector
self.time_range = time_range
self.batch_size = batch_size
if isinstance(self.connector, PollConnector):
if time_range is None:
raise ValueError("time_range is required for PollConnector")
self.doc_batch: list[Document] = []
self.doc_batch_generator = self.connector.poll_source(
time_range[0].timestamp(), time_range[1].timestamp()
)
elif isinstance(self.connector, LoadConnector):
if time_range and fail_loudly:
raise ValueError(
"time_range specified, but passed in connector is not a PollConnector"
)
self.doc_batch_generator = self.connector.load_from_state()
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
def run(self) -> GenerateDocumentsOutput:
def run(
self, checkpoint: ConnectorCheckpoint
) -> Generator[
tuple[
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
],
None,
None,
]:
"""Adds additional exception logging to the connector."""
try:
start = time.monotonic()
for batch in self.doc_batch_generator:
# to know how long connector is taking
logger.debug(
f"Connector took {time.monotonic() - start} seconds to build a batch."
)
yield batch
if isinstance(self.connector, CheckpointConnector):
if self.time_range is None:
raise ValueError("time_range is required for CheckpointConnector")
start = time.monotonic()
checkpoint_connector_generator = self.connector.load_from_checkpoint(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: ConnectorCheckpoint | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
checkpoint_connector_generator
):
if document is not None:
self.doc_batch.append(document)
if failure is not None:
yield None, failure, None
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
self.doc_batch = []
# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
self.doc_batch = []
yield None, None, next_checkpoint
logger.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
)
else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")
for document_batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
yield None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
yield None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
@@ -76,6 +172,6 @@ class ConnectorRunner:
)
logger.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str}"
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise

View File

@@ -30,12 +30,14 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.onyx_jira.connector import JiraConnector
@@ -43,7 +45,7 @@ from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
@@ -66,8 +68,8 @@ def identify_connector_class(
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.POLL: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
InputType.POLL: SlackConnector,
InputType.SLIM_RETRIEVAL: SlackConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
@@ -109,6 +111,8 @@ def identify_connector_class(
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}
connector_by_source = connector_map.get(source, {})
@@ -125,10 +129,23 @@ def identify_connector_class(
if any(
[
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector),
input_type == InputType.POLL and not issubclass(connector, PollConnector),
input_type == InputType.EVENT and not issubclass(connector, EventConnector),
(
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector)
),
(
input_type == InputType.POLL
# either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointConnector)
)
),
(
input_type == InputType.EVENT
and not issubclass(connector, EventConnector)
),
]
):
raise ConnectorMissingException(

View File

@@ -1,10 +1,13 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -14,6 +17,7 @@ SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
class BaseConnector(abc.ABC):
@@ -105,3 +109,33 @@ class EventConnector(BaseConnector):
@abc.abstractmethod
def handle_event(self, event: Any) -> GenerateDocumentsOutput:
raise NotImplementedError
class CheckpointConnector(BaseConnector):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Yields back documents or failures. Final return is the new checkpoint.
Final return can be accessed via either:
```
# note: a bare for loop swallows StopIteration, so call next() explicitly
generator = connector.load_from_checkpoint(start, end, checkpoint)
while True:
    try:
        document_or_failure = next(generator)
        print(document_or_failure)
    except StopIteration as e:
        checkpoint = e.value # Extracting the return value
        print(checkpoint)
        break
```
OR
```
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
```
"""
raise NotImplementedError
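# Illustrative sketch: a minimal, do-nothing CheckpointConnector. It yields no
# documents and signals completion by returning a checkpoint with has_more=False;
# a real connector would yield Document / ConnectorFailure objects along the way.
# The class is hypothetical and exists only to show the contract.
class _ExampleNoopCheckpointConnector(CheckpointConnector):
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        return None

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: ConnectorCheckpoint,
    ) -> CheckpointOutput:
        # the unreachable yield keeps this function a generator, as the
        # CheckpointOutput contract requires, even though nothing is produced
        if False:
            yield
        finished = ConnectorCheckpoint.build_dummy_checkpoint()
        finished.has_more = False
        return finished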

View File

@@ -0,0 +1,86 @@
from typing import Any
import httpx
from pydantic import BaseModel
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger
logger = setup_logger()
class SingleConnectorYield(BaseModel):
documents: list[Document]
checkpoint: ConnectorCheckpoint
failures: list[ConnectorFailure]
unhandled_exception: str | None = None
class MockConnector(CheckpointConnector):
def __init__(
self,
mock_server_host: str,
mock_server_port: int,
) -> None:
self.mock_server_host = mock_server_host
self.mock_server_port = mock_server_port
self.client = httpx.Client(timeout=30.0)
self.connector_yields: list[SingleConnectorYield] | None = None
self.current_yield_index: int = 0
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
response = self.client.get(self._get_mock_server_url("get-documents"))
response.raise_for_status()
data = response.json()
self.connector_yields = [
SingleConnectorYield(**yield_data) for yield_data in data
]
return None
def _get_mock_server_url(self, endpoint: str) -> str:
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
response = self.client.post(
self._get_mock_server_url("add-checkpoint"),
json=checkpoint.model_dump(mode="json"),
)
response.raise_for_status()
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
if self.connector_yields is None:
raise ValueError("No connector yields configured")
# Save the checkpoint to the mock server
self._save_checkpoint(checkpoint)
yield_index = self.current_yield_index
self.current_yield_index += 1
current_yield = self.connector_yields[yield_index]
# If the current yield has an unhandled exception, raise it
# This is used to simulate an unhandled failure in the connector.
if current_yield.unhandled_exception:
raise RuntimeError(current_yield.unhandled_exception)
# yield all documents
for document in current_yield.documents:
yield document
for failure in current_yield.failures:
yield failure
return current_yield.checkpoint
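# Illustrative sketch: the shape of the payload the mock server serves from
# /get-documents. Two yields: the first hands back an intermediate checkpoint, the
# second finishes the run. Field values are arbitrary examples.
def _example_connector_yields() -> list[SingleConnectorYield]:
    return [
        SingleConnectorYield(
            documents=[],
            checkpoint=ConnectorCheckpoint(
                checkpoint_content={"cursor": "page-2"}, has_more=True
            ),
            failures=[],
        ),
        SingleConnectorYield(
            documents=[],
            checkpoint=ConnectorCheckpoint(checkpoint_content={}, has_more=False),
            failures=[],
        ),
    ]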

View File

@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import model_validator
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -187,36 +188,48 @@ class SlimDocument(BaseModel):
perm_sync_data: Any | None = None
class DocumentErrorSummary(BaseModel):
id: str
semantic_id: str
section_link: str | None
@classmethod
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
return cls(
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
)
@classmethod
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
return cls(
id=str(data.get("id")),
semantic_id=str(data.get("semantic_id")),
section_link=str(data.get("section_link")),
)
def to_dict(self) -> dict[str, str | None]:
return {
"id": self.id,
"semantic_id": self.semantic_id,
"section_link": self.section_link,
}
class IndexAttemptMetadata(BaseModel):
batch_num: int | None = None
num_exceptions: int = 0
connector_id: int
credential_id: int
class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool
@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
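# Illustrative sketch: checkpoints are persisted as JSON (see checkpointing_utils),
# so a round trip through pydantic's JSON helpers preserves both fields. The
# checkpoint_content key used here is an arbitrary example.
def _example_checkpoint_round_trip() -> ConnectorCheckpoint:
    checkpoint = ConnectorCheckpoint(
        checkpoint_content={"last_seen_id": "doc-42"}, has_more=True
    )
    serialized = checkpoint.model_dump_json()
    return ConnectorCheckpoint.model_validate_json(serialized)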
class DocumentFailure(BaseModel):
document_id: str
document_link: str | None = None
class EntityFailure(BaseModel):
entity_id: str
missed_time_range: tuple[datetime, datetime] | None = None
class ConnectorFailure(BaseModel):
failed_document: DocumentFailure | None = None
failed_entity: EntityFailure | None = None
failure_message: str
exception: Exception | None = None
model_config = {"arbitrary_types_allowed": True}
@model_validator(mode="before")
def check_failed_fields(cls, values: dict) -> dict:
failed_document = values.get("failed_document")
failed_entity = values.get("failed_entity")
if (failed_document is None and failed_entity is None) or (
failed_document is not None and failed_entity is not None
):
raise ValueError(
"Exactly one of 'failed_document' or 'failed_entity' must be specified."
)
return values
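# Illustrative sketch: exactly one of failed_document / failed_entity must be set, so
# the first construction below succeeds while the second is rejected by the validator.
def _example_connector_failure() -> ConnectorFailure:
    ok = ConnectorFailure(
        failed_document=DocumentFailure(document_id="doc-1"),
        failure_message="example failure",
    )
    try:
        ConnectorFailure(failure_message="missing target")  # neither field set
    except ValueError:
        pass  # pydantic surfaces check_failed_fields as a validation error
    return ok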

View File

@@ -1,10 +1,16 @@
import contextvars
import copy
import re
from collections.abc import Callable
from collections.abc import Generator
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@@ -12,14 +18,18 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.utils import expert_info_from_slack_id
@@ -33,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLACK_LIMIT = 900
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
@@ -40,6 +52,13 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
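# Illustrative sketch: what a populated checkpoint_content looks like for this
# connector. The channel ID and timestamps are arbitrary example values.
def _example_slack_checkpoint_content() -> SlackCheckpointContent:
    return SlackCheckpointContent(
        channel_ids=["C01ABCDEF"],
        channel_completion_map={"C01ABCDEF": "1700000000.000100"},
        current_channel=None,
        seen_thread_ts=["1700000000.000200"],
    )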
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
@@ -140,6 +159,10 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
return f"{channel_id}__{thread_ts}"
def thread_to_doc(
channel: ChannelType,
thread: ThreadType,
@@ -182,7 +205,7 @@ def thread_to_doc(
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
@@ -267,64 +290,97 @@ def filter_channels(
]
def _get_all_docs(
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
"""Get a channel by its ID.
Args:
client: The Slack WebClient instance
channel_id: The ID of the channel to fetch
Returns:
The channel information
Raises:
SlackApiError: If the channel cannot be fetched
"""
response = make_slack_api_call_w_retries(
client.conversations_info,
channel=channel_id,
)
return cast(ChannelType, response["channel"])
def _get_messages(
channel: ChannelType,
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
) -> tuple[list[MessageType], bool]:
"""Slack goes from newest to oldest."""
# Cache to prevent refetching via API since users
user_cache: dict[str, BasicExpertInfo | None] = {}
# have to be in the channel in order to read messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
response = make_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
limit=_SLACK_LIMIT,
)
response.validate()
for channel in filtered_channels:
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
messages = cast(list[MessageType], response.get("messages", []))
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
"next_cursor", ""
)
has_more = bool(cursor)
return messages, has_more
def _message_to_doc(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Document | None:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
return None
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
return thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
seen_thread_ts: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
continue
seen_thread_ts.add(thread_ts)
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
logger.info(
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
)
return None
def _get_all_doc_ids(
@@ -368,7 +424,7 @@ def _get_all_doc_ids(
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=f"{channel_id}__{message_ts}",
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
perm_sync_data={"channel_id": channel_id},
)
)
@@ -376,7 +432,51 @@ def _get_all_doc_ids(
yield channel_metadata_list
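The SlimDocument id above is now produced by a shared _build_doc_id helper instead of the old inline f-string. A minimal sketch of what that helper presumably does, inferred from the previous f"{channel_id}__{message_ts}" format (the real implementation is not shown in this diff):

def _build_doc_id(channel_id: str, thread_ts: str) -> str:
    # Assumed to mirror the old inline format so existing document IDs stay stable.
    return f"{channel_id}__{thread_ts}"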
class SlackPollConnector(PollConnector, SlimConnector):
def _process_message(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
thread_ts = message.get("thread_ts")
try:
# causes random failures for testing checkpointing / continue on failure
# import random
# if random.random() > 0.95:
# raise RuntimeError("Random failure :P")
doc = _message_to_doc(
message=message,
client=client,
channel=channel,
slack_cleaner=slack_cleaner,
user_cache=user_cache,
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
),
document_link=get_message_link(message, client, channel["id"]),
),
failure_message=str(e),
exception=e,
),
)
class SlackConnector(SlimConnector, CheckpointConnector):
def __init__(
self,
channels: list[str] | None = None,
@@ -390,9 +490,14 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.batch_size = batch_size
self.client: WebClient | None = None
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
def retrieve_all_slim_documents(
@@ -411,30 +516,155 @@ class SlackPollConnector(PollConnector, SlimConnector):
callback=callback,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
Step 2: Loop through each channel. For each channel:
Step 2.1: Get messages within the time range.
Step 2.2: Process messages in parallel, yield back docs.
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
Slack returns messages from newest to oldest, so we need to keep track of
the latest message we've seen in each channel.
Step 2.4: If there are no more messages in the channel, switch the current
channel to the next channel.
"""
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
# throw an error if we use 0.0 on an account without infinite data
# retention
oldest=str(start) if start else None,
latest=str(end),
):
documents.append(document)
if len(documents) >= self.batch_size:
yield documents
documents = []
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
if documents:
yield documents
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
if len(filtered_channels) == 0:
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
message_batch, has_more_in_channel = _get_messages(
channel, self.client, oldest, latest
)
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=8) as executor:
futures: list[Future] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
futures.append(
executor.submit(
current_context.run,
_process_message,
message=message,
client=self.client,
channel=channel,
slack_cleaner=self.text_cleaner,
user_cache=self.user_cache,
seen_thread_ts=seen_thread_ts,
)
)
for future in as_completed(futures):
doc, thread_ts, failures = future.result()
if doc:
                        # seen_thread_ts is only written here, in this single-threaded loop;
                        # the multi-threaded _process_message calls merely read from it, so
                        # there are no concurrent writes. At worst we may emit a duplicate
                        # thread, which gets deduped later on.
if thread_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failures:
for failure in failures:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
return checkpoint
except Exception as e:
logger.exception(f"Error processing channel {channel['name']}")
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=channel["id"],
missed_time_range=(
datetime.fromtimestamp(start, tz=timezone.utc),
datetime.fromtimestamp(end, tz=timezone.utc),
),
),
failure_message=str(e),
exception=e,
)
return checkpoint
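The executor loop above copies the current contextvars context for every submitted message so that worker threads see the submitting thread's tenant. A self-contained sketch of that pattern, using an illustrative context variable rather than the actual Onyx tenant variable:

import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed

# Illustrative context variable; the real tenant contextvar lives elsewhere in Onyx.
tenant_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id", default="public")

def handle(item: str) -> str:
    # Runs in a worker thread but reads the context captured at submit time.
    return f"{tenant_id_var.get()}:{item}"

tenant_id_var.set("tenant_abc")
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = []
    for item in ["a", "b", "c"]:
        # Copy per task: a single Context object cannot be entered concurrently.
        ctx = contextvars.copy_context()
        futures.append(executor.submit(ctx.run, handle, item))
    results = [f.result() for f in as_completed(futures)]
print(results)  # order may vary, e.g. ['tenant_abc:a', 'tenant_abc:c', 'tenant_abc:b']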
if __name__ == "__main__":
@@ -442,7 +672,7 @@ if __name__ == "__main__":
import time
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
connector = SlackConnector(
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
@@ -450,6 +680,17 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
print(next(document_batches))
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):
print(document_or_failure)
elif isinstance(document_or_failure, ConnectorFailure):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value
print("Next checkpoint:", checkpoint)
print("Next checkpoint:", checkpoint)


@@ -34,9 +34,14 @@ def get_message_link(
) -> str:
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
message_ts_without_dot = message_ts.replace(".", "")
thread_ts = event.get("thread_ts")
base_url = get_base_url(client.token)
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
f"?thread_ts={thread_ts}" if thread_ts else ""
)
return link
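With this change the permalink is assembled locally instead of calling chat_getPermalink once per message. A quick illustration of the resulting format, using made-up IDs and a placeholder workspace URL:

message_ts = "1707777777.123456"
channel_id = "C0123456789"
base_url = "https://example.slack.com"  # placeholder workspace URL
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts.replace('.', '')}"
# -> https://example.slack.com/archives/C0123456789/p1707777777123456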
def _make_slack_api_call_paginated(


@@ -1,9 +1,14 @@
import os
import tempfile
import urllib.parse
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
from zulip import Client
@@ -36,8 +41,39 @@ class ZulipConnector(LoadConnector, PollConnector):
) -> None:
self.batch_size = batch_size
self.realm_name = realm_name
self.realm_url = realm_url if realm_url.endswith("/") else realm_url + "/"
self.client: Client | None = None
# Clean and normalize the URL
realm_url = realm_url.strip().lower()
# Remove any trailing slashes
realm_url = realm_url.rstrip("/")
# Ensure the URL has a scheme
if not realm_url.startswith(("http://", "https://")):
realm_url = f"https://{realm_url}"
try:
parsed = urllib.parse.urlparse(realm_url)
# Extract the base domain without any paths or ports
netloc = parsed.netloc.split(":")[0] # Remove port if present
if not netloc:
raise ValueError(
f"Invalid realm URL format: {realm_url}. "
f"URL must include a valid domain name."
)
# Always use HTTPS for security
self.base_url = f"https://{netloc}"
self.client: Client | None = None
except Exception as e:
raise ValueError(
f"Failed to parse Zulip realm URL: {realm_url}. "
f"Please provide a URL in the format: domain.com or https://domain.com. "
f"Error: {str(e)}"
)
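To make the normalization above concrete, here is a small standalone sketch that applies the same steps (the helper name is illustrative, not part of the connector):

import urllib.parse

def normalize_realm_url(realm_url: str) -> str:
    # Lowercase, drop trailing slashes, default to https, keep only the host.
    realm_url = realm_url.strip().lower().rstrip("/")
    if not realm_url.startswith(("http://", "https://")):
        realm_url = f"https://{realm_url}"
    netloc = urllib.parse.urlparse(realm_url).netloc.split(":")[0]
    return f"https://{netloc}"

assert normalize_realm_url("Example.Zulipchat.com") == "https://example.zulipchat.com"
assert normalize_realm_url("example.zulipchat.com:443/path") == "https://example.zulipchat.com"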
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
contents = credentials["zuliprc_content"]
@@ -55,12 +91,17 @@ class ZulipConnector(LoadConnector, PollConnector):
return None
def _message_to_narrow_link(self, m: Message) -> str:
stream_name = m.display_recipient # assume str
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
topic_operand = encode_zulip_narrow_operand(m.subject)
try:
stream_name = m.display_recipient # assume str
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
topic_operand = encode_zulip_narrow_operand(m.subject)
narrow_link = f"{self.realm_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
return narrow_link
narrow_link = f"{self.base_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
return narrow_link
except Exception as e:
logger.error(f"Error generating Zulip message link: {e}")
# Fallback to a basic link that at least includes the base URL
return f"{self.base_url}#narrow/id/{m.id}"
def _get_message_batch(self, anchor: str) -> Tuple[bool, List[Message]]:
if self.client is None:
@@ -83,6 +124,40 @@ class ZulipConnector(LoadConnector, PollConnector):
def _message_to_doc(self, message: Message) -> Document:
text = f"{message.sender_full_name}: {message.content}"
try:
# Convert timestamps to UTC datetime objects
post_time = datetime.fromtimestamp(message.timestamp, tz=timezone.utc)
edit_time = (
datetime.fromtimestamp(message.last_edit_timestamp, tz=timezone.utc)
if message.last_edit_timestamp is not None
else None
)
# Use the most recent edit time if available, otherwise use post time
doc_time = edit_time if edit_time is not None else post_time
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse timestamp for message {message.id}: {e}")
post_time = None
edit_time = None
doc_time = None
metadata: Dict[str, Union[str, List[str]]] = {
"stream_name": str(message.display_recipient),
"topic": str(message.subject),
"sender_name": str(message.sender_full_name),
"sender_email": str(message.sender_email),
"message_timestamp": str(message.timestamp),
"message_id": str(message.id),
"stream_id": str(message.stream_id),
"has_reactions": str(len(message.reactions) > 0),
"content_type": str(message.content_type or "text"),
}
# Always include edit timestamp in metadata when available
if edit_time is not None:
metadata["edit_timestamp"] = str(message.last_edit_timestamp)
return Document(
id=f"{message.stream_id}__{message.id}",
sections=[
@@ -92,8 +167,9 @@ class ZulipConnector(LoadConnector, PollConnector):
)
],
source=DocumentSource.ZULIP,
semantic_identifier=message.display_recipient or message.subject,
metadata={},
semantic_identifier=f"{message.display_recipient} > {message.subject}",
metadata=metadata,
doc_updated_at=doc_time, # Use most recent edit time or post time
)
def _get_docs(


@@ -1,6 +1,7 @@
from typing import Any
from typing import List
from typing import Optional
from typing import Union
from pydantic import BaseModel
from pydantic import Field
@@ -19,7 +20,7 @@ class Message(BaseModel):
sender_realm_str: str
subject: str
topic_links: Optional[List[Any]] = None
last_edit_timestamp: Optional[int]
last_edit_timestamp: Optional[int] = None
edit_history: Any = None
reactions: List[Any]
submessages: List[Any]
@@ -39,5 +40,5 @@ class GetMessagesResponse(BaseModel):
found_oldest: Optional[bool] = None
found_newest: Optional[bool] = None
history_limited: Optional[bool] = None
anchor: Optional[str] = None
anchor: Optional[Union[str, int]] = None
messages: List[Message] = Field(default_factory=list)


@@ -354,7 +354,10 @@ def delete_chat_session(
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
chat_session_id=chat_session_id,
user_id=user_id,
db_session=db_session,
include_deleted=include_deleted,
)
if chat_session.deleted and not include_deleted:


@@ -18,6 +18,7 @@ import boto3
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -39,6 +40,7 @@ from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
@@ -187,20 +189,45 @@ class SqlEngine:
_engine: Engine | None = None
_lock: threading.Lock = threading.Lock()
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
DEFAULT_ENGINE_KWARGS = {
"pool_size": 20,
"max_overflow": 5,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
def _init_engine(
cls, host: str, port: str, db: str, **engine_kwargs: Any
) -> Engine:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
db_api=SYNC_DB_API,
host=host,
port=port,
db=db,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
engine = create_engine(connection_string, **merged_kwargs)
# Start with base kwargs that are valid for all pool types
final_engine_kwargs: dict[str, Any] = {}
if POSTGRES_USE_NULL_POOL:
# if null pool is specified, then we need to make sure that
# we remove any passed in kwargs related to pool size that would
# cause the initialization to fail
final_engine_kwargs.update(engine_kwargs)
final_engine_kwargs["poolclass"] = pool.NullPool
if "pool_size" in final_engine_kwargs:
del final_engine_kwargs["pool_size"]
if "max_overflow" in final_engine_kwargs:
del final_engine_kwargs["max_overflow"]
else:
final_engine_kwargs["pool_size"] = 20
final_engine_kwargs["max_overflow"] = 5
final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE
# any passed in kwargs override the defaults
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
event.listen(engine, "do_connect", provide_iam_token)
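For reference, a minimal standalone illustration of the two pooling modes this toggle chooses between (placeholder DSN; create_engine does not connect until first use):

from sqlalchemy import create_engine, pool

dsn = "postgresql+psycopg2://user:pass@localhost:5432/onyx"  # placeholder

# Default mode: bounded QueuePool that keeps connections alive between requests.
pooled_engine = create_engine(dsn, pool_size=20, max_overflow=5, pool_pre_ping=True)

# NullPool mode: no pooling; pool_size/max_overflow must be stripped or create_engine raises.
null_pool_engine = create_engine(dsn, poolclass=pool.NullPool, pool_pre_ping=True)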
@@ -211,15 +238,19 @@ class SqlEngine:
def init_engine(cls, **engine_kwargs: Any) -> None:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
cls._engine = cls._init_engine(
host=engine_kwargs.get("host", POSTGRES_HOST),
port=engine_kwargs.get("port", POSTGRES_PORT),
db=engine_kwargs.get("db", POSTGRES_DB),
**engine_kwargs,
)
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
cls.init_engine()
return cls._engine # type: ignore
@classmethod
def set_app_name(cls, app_name: str) -> None:
@@ -299,13 +330,21 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
connect_args["ssl"] = ssl_context
engine_kwargs = {
"connect_args": connect_args,
"pool_pre_ping": POSTGRES_POOL_PRE_PING,
"pool_recycle": POSTGRES_POOL_RECYCLE,
}
if POSTGRES_USE_NULL_POOL:
engine_kwargs["poolclass"] = pool.NullPool
else:
engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW
_ASYNC_ENGINE = create_async_engine(
connection_string,
connect_args=connect_args,
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
**engine_kwargs,
)
if USE_IAM_AUTH:


@@ -11,8 +11,7 @@ from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentErrorSummary
from onyx.connectors.models import ConnectorFailure
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -41,6 +40,27 @@ def get_last_attempt_for_cc_pair(
)
def get_recent_completed_attempts_for_cc_pair(
cc_pair_id: int,
search_settings_id: int,
limit: int,
db_session: Session,
) -> list[IndexAttempt]:
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.search_settings_id == search_settings_id,
IndexAttempt.status.notin_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
.order_by(IndexAttempt.time_updated.desc())
.limit(limit)
.all()
)
def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
@@ -615,23 +635,32 @@ def count_unique_cc_pairs_with_successful_index_attempts(
def create_index_attempt_error(
index_attempt_id: int | None,
batch: int | None,
docs: list[Document],
exception_msg: str,
exception_traceback: str,
connector_credential_pair_id: int,
failure: ConnectorFailure,
db_session: Session,
) -> int:
doc_summaries = []
for doc in docs:
doc_summary = DocumentErrorSummary.from_document(doc)
doc_summaries.append(doc_summary.to_dict())
new_error = IndexAttemptError(
index_attempt_id=index_attempt_id,
batch=batch,
doc_summaries=doc_summaries,
error_msg=exception_msg,
traceback=exception_traceback,
connector_credential_pair_id=connector_credential_pair_id,
document_id=(
failure.failed_document.document_id if failure.failed_document else None
),
document_link=(
failure.failed_document.document_link if failure.failed_document else None
),
entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None),
failed_time_range_start=(
failure.failed_entity.missed_time_range[0]
if failure.failed_entity and failure.failed_entity.missed_time_range
else None
),
failed_time_range_end=(
failure.failed_entity.missed_time_range[1]
if failure.failed_entity and failure.failed_entity.missed_time_range
else None
),
failure_message=failure.failure_message,
is_resolved=False,
)
db_session.add(new_error)
db_session.commit()
@@ -649,3 +678,42 @@ def get_index_attempt_errors(
errors = db_session.scalars(stmt)
return list(errors.all())
def count_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
unresolved_only: bool,
db_session: Session,
) -> int:
stmt = (
select(func.count())
.select_from(IndexAttemptError)
.where(IndexAttemptError.connector_credential_pair_id == cc_pair_id)
)
if unresolved_only:
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
result = db_session.scalar(stmt)
return 0 if result is None else result
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
unresolved_only: bool,
db_session: Session,
page: int | None = None,
page_size: int | None = None,
) -> list[IndexAttemptError]:
stmt = select(IndexAttemptError).where(
IndexAttemptError.connector_credential_pair_id == cc_pair_id
)
if unresolved_only:
stmt = stmt.where(IndexAttemptError.is_resolved.is_(False))
# Order by most recent first
stmt = stmt.order_by(desc(IndexAttemptError.time_created))
if page is not None and page_size is not None:
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.scalars(stmt).all())


@@ -827,6 +827,19 @@ class IndexAttempt(Base):
nullable=True,
)
# for polling connectors, the start and end time of the poll window
# will be set when the index attempt starts
poll_range_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)
poll_range_end: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)
# Points to the last checkpoint that was saved for this run. The pointer here
# can be taken to the FileStore to grab the actual checkpoint value
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -870,6 +883,13 @@ class IndexAttempt(Base):
desc("time_updated"),
unique=False,
),
Index(
"ix_index_attempt_cc_pair_settings_poll",
"connector_credential_pair_id",
"search_settings_id",
"status",
desc("time_updated"),
),
)
def __repr__(self) -> str:
@@ -886,25 +906,33 @@ class IndexAttempt(Base):
class IndexAttemptError(Base):
"""
Represents an error that was encountered during an IndexAttempt.
"""
__tablename__ = "index_attempt_errors"
id: Mapped[int] = mapped_column(primary_key=True)
index_attempt_id: Mapped[int] = mapped_column(
ForeignKey("index_attempt.id"),
nullable=True,
nullable=False,
)
connector_credential_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"),
nullable=False,
)
# The index of the batch where the error occurred (if looping thru batches)
# Just informational.
batch: Mapped[int | None] = mapped_column(Integer, default=None)
doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB())
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
traceback: Mapped[str | None] = mapped_column(Text, default=None)
document_id: Mapped[str | None] = mapped_column(String, nullable=True)
document_link: Mapped[str | None] = mapped_column(String, nullable=True)
entity_id: Mapped[str | None] = mapped_column(String, nullable=True)
failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
failure_message: Mapped[str] = mapped_column(Text)
is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -913,21 +941,6 @@ class IndexAttemptError(Base):
# This is the reverse side of the relationship
index_attempt = relationship("IndexAttempt", back_populates="error_rows")
__table_args__ = (
Index(
"index_attempt_id",
"time_created",
),
)
def __repr__(self) -> str:
return (
f"<IndexAttempt(id={self.id!r}, "
f"index_attempt_id={self.index_attempt_id!r}, "
f"error_msg={self.error_msg!r})>"
f"time_created={self.time_created!r}, "
)
class SyncRecord(Base):
"""


@@ -6,6 +6,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
def get_default_document_index(
@@ -23,14 +24,27 @@ def get_default_document_index(
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
# modify index names for integration tests so that we can run many tests
# using the same Vespa instance w/o having them collide
primary_index_name = search_settings.index_name
if VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY:
primary_index_name = (
f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{primary_index_name}"
)
if secondary_index_name:
secondary_index_name = f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{secondary_index_name}"
# Currently only supporting Vespa
return VespaIndex(
index_name=search_settings.index_name,
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
preserve_existing_indices=bool(
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
),
)


@@ -136,6 +136,7 @@ class VespaIndex(DocumentIndex):
secondary_large_chunks_enabled: bool | None,
multitenant: bool = False,
httpx_client: httpx.Client | None = None,
preserve_existing_indices: bool = False,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
@@ -161,18 +162,18 @@ class VespaIndex(DocumentIndex):
secondary_index_name
] = secondary_large_chunks_enabled
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
)
return None
self.preserve_existing_indices = preserve_existing_indices
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
@classmethod
def create_indices(
cls,
indices: list[tuple[str, int, bool]],
application_endpoint: str = VESPA_APPLICATION_ENDPOINT,
) -> None:
"""
Create indices in Vespa based on the passed in configuration(s).
"""
deploy_url = f"{application_endpoint}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
@@ -185,7 +186,7 @@ class VespaIndex(DocumentIndex):
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_names = [self.index_name, self.secondary_index_name]
schema_names = [index_name for (index_name, _, _) in indices]
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
@@ -193,14 +194,6 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
needs_reindexing = False
try:
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
@@ -221,29 +214,63 @@ class VespaIndex(DocumentIndex):
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
for index_name, index_embedding_dim, needs_reindexing in indices:
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
logger.info(
f"Creating index: {index_name} with embedding "
f"dimension: {index_embedding_dim}. Schema:\n\n {schema}"
)
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
logger.error(f"Failed to create Vespa indices: {response.text}")
raise RuntimeError(
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
)
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if self.multitenant or MULTI_TENANT: # be extra safe here
logger.info(
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None
# Used in IT
# NOTE: this means that we can't switch embedding models
if self.preserve_existing_indices:
logger.info("Preserving existing indices")
return None
kv_store = get_kv_store()
primary_needs_reindexing = False
try:
primary_needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
indices = [
(self.index_name, index_embedding_dim, primary_needs_reindexing),
]
if self.secondary_index_name and secondary_index_embedding_dim:
indices.append(
(self.secondary_index_name, secondary_index_embedding_dim, False)
)
self.create_indices(indices)
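Because create_indices is now a classmethod that takes (index_name, embedding_dim, needs_reindexing) tuples, integration tests can provision several schemas in one deploy. A hedged usage sketch with made-up names and dimensions (requires a reachable Vespa config endpoint):

VespaIndex.create_indices(
    indices=[
        ("it_run_1_danswer_chunk", 768, False),        # primary index
        ("it_run_1_danswer_chunk_small", 384, False),  # secondary / alternate model
    ],
)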
@staticmethod
def register_multitenant_indices(
indices: list[str],


@@ -1,6 +1,10 @@
import time
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import ChunkEmbedding
@@ -217,3 +221,49 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
deployment_name=search_settings.deployment_name,
callback=callback,
)
def embed_chunks_with_failure_handling(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
"""Tries to embed all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
"""
# First try to embed all chunks in one batch
try:
return embedder.embed_chunks(chunks=chunks), []
except Exception:
logger.exception("Failed to embed chunk batch. Trying individual docs.")
# wait a couple seconds to let any rate limits or temporary issues resolve
time.sleep(2)
# Try embedding each document's chunks individually
chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list)
for chunk in chunks:
chunks_by_doc[chunk.source_document.id].append(chunk)
embedded_chunks: list[IndexChunk] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_by_doc.items():
try:
doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
embedded_chunks.extend(doc_embedded_chunks)
except Exception as e:
logger.exception(f"Failed to embed chunks for document '{doc_id}'")
failures.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
),
failure_message=str(e),
exception=e,
)
)
return embedded_chunks, failures
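The same batch-first, per-document-fallback shape shows up again below for vector DB writes. A generic sketch of the pattern, independent of embedders or Vespa:

from collections import defaultdict
from typing import Callable, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def process_with_per_group_fallback(
    items: list[T],
    group_key: Callable[[T], str],
    process: Callable[[list[T]], list[R]],
) -> tuple[list[R], list[str]]:
    """Try the whole batch once; on failure, retry group by group and
    report the keys of the groups that still fail."""
    try:
        return process(items), []
    except Exception:
        pass
    results: list[R] = []
    failed_keys: list[str] = []
    grouped: dict[str, list[T]] = defaultdict(list)
    for item in items:
        grouped[group_key(item)].append(item)
    for key, group in grouped.items():
        try:
            results.extend(process(group))
        except Exception:
            failed_keys.append(key)
    return results, failed_keys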


@@ -1,23 +1,21 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol
import httpx
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
@@ -29,7 +27,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.models import Document as DBDocument
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
@@ -41,10 +38,12 @@ from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
@@ -67,6 +66,8 @@ class IndexingPipelineResult(BaseModel):
# number of chunks that were inserted into Vespa
total_chunks: int
failures: list[ConnectorFailure]
class IndexingPipelineProtocol(Protocol):
def __call__(
@@ -156,14 +157,10 @@ def index_doc_batch_with_handler(
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
attempt_id: int | None,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> IndexingPipelineResult:
index_pipeline_result = IndexingPipelineResult(
new_docs=0, total_docs=len(document_batch), total_chunks=0
)
try:
index_pipeline_result = index_doc_batch(
chunker=chunker,
@@ -176,47 +173,25 @@ def index_doc_batch_with_handler(
tenant_id=tenant_id,
)
except Exception as e:
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage indicates "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
logger.exception(f"Failed to index document batch: {document_batch}")
index_pipeline_result = IndexingPipelineResult(
new_docs=0,
total_docs=len(document_batch),
total_chunks=0,
failures=[
ConnectorFailure(
failed_document=DocumentFailure(
document_id=document.id,
document_link=(
document.sections[0].link if document.sections else None
),
),
failure_message=str(e),
exception=e,
)
if INDEXING_EXCEPTION_LIMIT == 0:
raise
trace = traceback.format_exc()
create_index_attempt_error(
attempt_id,
batch=index_attempt_metadata.batch_num,
docs=document_batch,
exception_msg=str(e),
exception_traceback=trace,
db_session=db_session,
for document in document_batch
],
)
logger.exception(
f"Indexing batch {index_attempt_metadata.batch_num} failed. msg='{e}' trace='{trace}'"
)
index_attempt_metadata.num_exceptions += 1
if index_attempt_metadata.num_exceptions == INDEXING_EXCEPTION_LIMIT:
logger.warning(
f"Maximum number of exceptions for this index attempt "
f"({INDEXING_EXCEPTION_LIMIT}) has been reached. "
f"The next exception will abort the indexing attempt."
)
elif index_attempt_metadata.num_exceptions > INDEXING_EXCEPTION_LIMIT:
logger.warning(
f"Maximum number of exceptions for this index attempt "
f"({INDEXING_EXCEPTION_LIMIT}) has been exceeded."
)
raise RuntimeError(
f"Maximum exception limit of {INDEXING_EXCEPTION_LIMIT} exceeded."
)
else:
pass
return index_pipeline_result
@@ -376,8 +351,12 @@ def index_doc_batch(
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
db_session.commit()
return IndexingPipelineResult(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
new_docs=0,
total_docs=len(filtered_documents),
total_chunks=0,
failures=[],
)
doc_descriptors = [
@@ -390,10 +369,19 @@ def index_doc_batch(
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
logger.debug("Starting chunking")
# NOTE: no special handling for failures here, since the chunker is not
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
)
if chunks
else ([], [])
)
updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -459,7 +447,11 @@ def index_doc_batch(
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set
insertion_records = document_index.index(
(
insertion_records,
vector_db_write_failures,
) = write_chunks_to_vector_db_with_backoff(
document_index=document_index,
chunks=access_aware_chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
@@ -519,6 +511,7 @@ def index_doc_batch(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
failures=vector_db_write_failures + embedding_failures,
)
return result
@@ -531,7 +524,6 @@ def build_indexing_pipeline(
db_session: Session,
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
@@ -553,7 +545,6 @@ def build_indexing_pipeline(
embedder=embedder,
document_index=document_index,
ignore_time_skip=ignore_time_skip,
attempt_id=attempt_id,
db_session=db_session,
tenant_id=tenant_id,
)


@@ -57,6 +57,13 @@ class DocAwareChunk(BaseChunk):
"""Used when logging the identity of a chunk"""
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
def get_link(self) -> str | None:
return (
self.source_document.sections[0].link
if self.source_document.sections
else None
)
class IndexChunk(DocAwareChunk):
embeddings: ChunkEmbedding


@@ -0,0 +1,99 @@
import time
from collections import defaultdict
from http import HTTPStatus
import httpx
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import IndexBatchParams
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _log_insufficient_storage_error(e: Exception) -> None:
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
logger.error(
"NOTE: HTTP Status 507 Insufficient Storage indicates "
"you need to allocate more memory or disk space to the "
"Vespa/index container."
)
def write_chunks_to_vector_db_with_backoff(
document_index: DocumentIndex,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
    IMPORTANT: must pass in whole documents at a time, not individual chunks, since the
    vector DB interface assumes that all chunks for a single document are present.
"""
# first try to write the chunks to the vector db
try:
return (
list(
document_index.index(
chunks=chunks,
index_batch_params=index_batch_params,
)
),
[],
)
except Exception as e:
logger.exception(
"Failed to write chunk batch to vector db. Trying individual docs."
)
# give some specific logging on this common failure case.
_log_insufficient_storage_error(e)
# wait a couple seconds just to give the vector db a chance to recover
time.sleep(2)
# try writing each doc one by one
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
for chunk in chunks:
chunks_for_docs[chunk.source_document.id].append(chunk)
insertion_records: list[DocumentInsertionRecord] = []
failures: list[ConnectorFailure] = []
for doc_id, chunks_for_doc in chunks_for_docs.items():
try:
insertion_records.extend(
document_index.index(
chunks=chunks_for_doc,
index_batch_params=index_batch_params,
)
)
except Exception as e:
logger.exception(
f"Failed to write document chunks for '{doc_id}' to vector db"
)
# give some specific logging on this common failure case.
_log_insufficient_storage_error(e)
failures.append(
ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=(
chunks_for_doc[0].get_link() if chunks_for_doc else None
),
),
failure_message=str(e),
exception=e,
)
)
return insertion_records, failures


@@ -43,6 +43,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.configs.integration_test_configs import SKIP_CONNECTION_POOL_WARM_UP
from onyx.db.engine import SqlEngine
from onyx.db.engine import warm_up_connections
from onyx.server.api_key.api import router as api_key_router
@@ -51,7 +52,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.indexing import router as indexing_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
@@ -209,8 +209,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if DISABLE_GENERATIVE_AI:
logger.notice("Generative AI Q&A disabled")
# fill up Postgres connection pools
await warm_up_connections()
# only used for IT. Need to skip since it overloads postgres when we have 50+
# instances running
if not SKIP_CONNECTION_POOL_WARM_UP:
# fill up Postgres connection pools
await warm_up_connections()
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
@@ -251,7 +254,6 @@ def log_http_error(request: Request, exc: Exception) -> JSONResponse:
logger.debug(f"404 error for /metrics endpoint: {str(exc)}")
elif status_code >= 400:
print("FORMATTING ERROR")
error_msg = f"{str(exc)}\n"
error_msg += "".join(traceback.format_tb(exc.__traceback__))
logger.error(error_msg)
@@ -318,7 +320,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)


@@ -146,10 +146,10 @@ class RedisPool:
cls._instance._init_pools()
return cls._instance
def _init_pools(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def _init_pools(self, redis_port: int = REDIS_PORT) -> None:
self._pool = RedisPool.create_pool(port=redis_port, ssl=REDIS_SSL)
self._replica_pool = RedisPool.create_pool(
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
host=REDIS_REPLICA_HOST, port=redis_port, ssl=REDIS_SSL
)
def get_client(self, tenant_id: str | None) -> Redis:


@@ -61,10 +61,10 @@ def _create_indexable_chunks(
doc_updated_at=None,
primary_owners=[],
secondary_owners=[],
chunk_count=1,
chunk_count=preprocessed_doc["chunk_ind"] + 1,
)
if preprocessed_doc["chunk_ind"] == 0:
ids_to_documents[document.id] = document
ids_to_documents[document.id] = document
chunk = DocMetadataAwareIndexChunk(
chunk_id=preprocessed_doc["chunk_ind"],
@@ -92,6 +92,7 @@ def _create_indexable_chunks(
boost=DEFAULT_BOOST,
large_chunk_id=None,
)
chunks.append(chunk)
return list(ids_to_documents.values()), chunks
@@ -192,6 +193,7 @@ def seed_initial_documents(
last_successful_index_time=last_index_time,
seeding_flow=True,
)
cc_pair_id = cast(int, result.data)
processed_docs = fetch_versioned_implementation(
"onyx.seeding.load_docs",
@@ -249,4 +251,5 @@ def seed_initial_documents(
.values(chunk_count=doc.chunk_count)
)
db_session.commit()
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)


@@ -22,6 +22,7 @@ from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import add_credential_to_connector
@@ -39,7 +40,9 @@ from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import count_index_attempts_for_connector
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.models import SearchSettings
@@ -546,6 +549,47 @@ def get_docs_sync_status(
return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair]
@router.get("/admin/cc-pair/{cc_pair_id}/errors")
def get_cc_pair_indexing_errors(
cc_pair_id: int,
include_resolved: bool = Query(False),
page: int = Query(0, ge=0),
page_size: int = Query(10, ge=1, le=100),
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[IndexAttemptErrorPydantic]:
"""Gives back all errors for a given CC Pair. Allows pagination based on page and page_size params.
Args:
cc_pair_id: ID of the connector-credential pair to get errors for
include_resolved: Whether to include resolved errors in the results
page: Page number for pagination, starting at 0
page_size: Number of errors to return per page
_: Current user, must be curator or admin
db_session: Database session
Returns:
Paginated list of indexing errors for the CC pair.
"""
total_count = count_index_attempt_errors_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair_id,
unresolved_only=not include_resolved,
)
index_attempt_errors = get_index_attempt_errors_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair_id,
unresolved_only=not include_resolved,
page=page,
page_size=page_size,
)
return PaginatedReturn(
items=[IndexAttemptErrorPydantic.from_model(e) for e in index_attempt_errors],
total_items=total_count,
)
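A hedged example of paging through this endpoint from a script; the /manage prefix and the auth header format are assumptions about how the router is mounted and secured:

import requests

resp = requests.get(
    "http://localhost:8080/manage/admin/cc-pair/42/errors",  # assumed mount path
    params={"include_resolved": False, "page": 0, "page_size": 10},
    headers={"Authorization": "Bearer <api-key>"},  # placeholder credentials
)
resp.raise_for_status()
payload = resp.json()
print(payload["total_items"], len(payload["items"]))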
@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,


@@ -22,6 +22,7 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
@@ -613,6 +614,16 @@ def get_connector_indexing_status(
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
if MOCK_CONNECTOR_FILE_PATH:
import json
with open(MOCK_CONNECTOR_FILE_PATH, "r") as f:
raw_data = json.load(f)
connector_indexing_statuses = [
ConnectorIndexingStatus(**status) for status in raw_data
]
return connector_indexing_statuses
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
# connector or credential may be None.


@@ -1,23 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine import get_session
from onyx.db.index_attempt import (
get_index_attempt_errors,
)
from onyx.db.models import User
from onyx.server.documents.models import IndexAttemptError
router = APIRouter(prefix="/manage")
@router.get("/admin/indexing-errors/{index_attempt_id}")
def get_indexing_errors(
index_attempt_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[IndexAttemptError]:
indexing_errors = get_index_attempt_errors(index_attempt_id, db_session)
return [IndexAttemptError.from_db_model(e) for e in indexing_errors]


@@ -8,9 +8,9 @@ from pydantic import BaseModel
from pydantic import Field
from ee.onyx.server.query_history.models import ChatSessionMinimal
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import DocumentErrorSummary
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -19,7 +19,6 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Document as DbDocument
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.models import FullUserSnapshot
@@ -150,6 +149,7 @@ class CredentialSnapshot(CredentialBase):
class IndexAttemptSnapshot(BaseModel):
id: int
status: IndexingStatus | None
from_beginning: bool
new_docs_indexed: int # only includes completely new docs
total_docs_indexed: int # includes docs that are updated
docs_removed_from_index: int
@@ -166,6 +166,7 @@ class IndexAttemptSnapshot(BaseModel):
return IndexAttemptSnapshot(
id=index_attempt.id,
status=index_attempt.status,
from_beginning=index_attempt.from_beginning,
new_docs_indexed=index_attempt.new_docs_indexed or 0,
total_docs_indexed=index_attempt.total_docs_indexed or 0,
docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
@@ -181,31 +182,6 @@ class IndexAttemptSnapshot(BaseModel):
)
class IndexAttemptError(BaseModel):
id: int
index_attempt_id: int | None
batch_number: int | None
doc_summaries: list[DocumentErrorSummary]
error_msg: str | None
traceback: str | None
time_created: str
@classmethod
def from_db_model(cls, error: DbIndexAttemptError) -> "IndexAttemptError":
doc_summaries = [
DocumentErrorSummary.from_dict(summary) for summary in error.doc_summaries
]
return IndexAttemptError(
id=error.id,
index_attempt_id=error.index_attempt_id,
batch_number=error.batch,
doc_summaries=doc_summaries,
error_msg=error.error_msg,
traceback=error.traceback,
time_created=error.time_created.isoformat(),
)
# These are the types currently supported by the pagination hook
# More api endpoints can be refactored and be added here for use with the pagination hook
PaginatedType = TypeVar(
@@ -214,6 +190,7 @@ PaginatedType = TypeVar(
FullUserSnapshot,
InvitedUserSnapshot,
ChatSessionMinimal,
IndexAttemptErrorPydantic,
)


@@ -19,7 +19,8 @@ def load_settings() -> Settings:
Settings.model_validate(stored_settings) if stored_settings else Settings()
)
except KvKeyNotFoundError:
logger.error(f"No settings found in KV store for key: {KV_SETTINGS_KEY}")
# Default to empty settings if no settings have been set yet
logger.debug(f"No settings found in KV store for key: {KV_SETTINGS_KEY}")
settings = Settings()
except Exception as e:
logger.error(f"Error loading settings from KV store: {str(e)}")


@@ -251,7 +251,8 @@ def setup_vespa(
logger.notice("Vespa setup complete.")
return True
except Exception:
except Exception as e:
logger.debug(f"Error creating Vespa indices: {e}")
logger.notice(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)


@@ -0,0 +1,26 @@
import sys
from typing import TypeVar
T = TypeVar("T", dict, list, tuple, set, frozenset)
def deep_getsizeof(obj: T, seen: set[int] | None = None) -> int:
"""Recursively sum size of objects, handling circular references."""
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0 # Prevent infinite recursion for circular references
seen.add(obj_id)
size = sys.getsizeof(obj)
if isinstance(obj, dict):
size += sum(
deep_getsizeof(k, seen) + deep_getsizeof(v, seen) for k, v in obj.items()
)
elif isinstance(obj, (list, tuple, set, frozenset)):
size += sum(deep_getsizeof(i, seen) for i in obj)
return size
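A quick usage example for the helper above, including the circular-reference case it guards against:

nested = {"a": [1, 2, 3], "b": {"c": (4, 5)}}
nested["self"] = nested  # circular reference, handled via the `seen` set
print(deep_getsizeof(nested))  # deep size in bytes, each object counted once
print(sys.getsizeof(nested))   # shallow size only, for comparison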


@@ -42,7 +42,7 @@ def run_jobs() -> None:
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
]
cmd_worker_heavy = [


@@ -270,3 +270,10 @@ SUPPORTED_EMBEDDING_MODELS = [
index_name="danswer_chunk_intfloat_multilingual_e5_small",
),
]
"""
INTEGRATION TEST ONLY SETTINGS
"""
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY = os.getenv(
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY"
)


@@ -33,7 +33,7 @@ stopasgroup=true
command=celery -A onyx.background.celery.versioned_apps.light worker
--loglevel=INFO
--hostname=light@%%n
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup
stdout_logfile=/var/log/celery_worker_light.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true


@@ -3,6 +3,19 @@ FROM python:3.11.7-slim-bookworm
# Currently needs all dependencies, since the ITs use some of the Onyx
# backend code.
# Add Docker's official GPG key and repository for Debian
RUN apt-get update && \
apt-get install -y ca-certificates curl && \
install -m 0755 -d /etc/apt/keyrings && \
curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
chmod a+r /etc/apt/keyrings/docker.asc && \
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
tee /etc/apt/sources.list.d/docker.list > /dev/null && \
apt-get update
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -15,6 +28,9 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
postgresql-client \
# Install Docker for DinD
docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
@@ -29,37 +45,19 @@ RUN apt-get update && \
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/model_server.txt /tmp/model_server-requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/model_server-requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
# https://github.com/tornadoweb/tornado/issues/3107
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Set up application files
WORKDIR /app
@@ -76,6 +74,9 @@ COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# need to copy over model server as well, since we're running it in the same container
COPY ./model_server /app/model_server
# Integration test stuff
COPY ./requirements/dev.txt /tmp/dev-requirements.txt
RUN pip install --no-cache-dir --upgrade \
@@ -84,5 +85,6 @@ COPY ./tests/integration /app/tests/integration
ENV PYTHONPATH=/app
ENTRYPOINT ["pytest", "-s"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
ENTRYPOINT []
# let caller specify the command
CMD ["tail", "-f", "/dev/null"]

View File

@@ -1,14 +1,15 @@
import requests
from sqlalchemy.orm import Session
from onyx.db.engine import get_session_context_manager
from onyx.db.models import User
def test_create_chat_session_and_send_messages(db_session: Session) -> None:
def test_create_chat_session_and_send_messages() -> None:
# Create a test user
test_user = User(email="test@example.com", hashed_password="dummy_hash")
db_session.add(test_user)
db_session.commit()
with get_session_context_manager() as db_session:
test_user = User(email="test@example.com", hashed_password="dummy_hash")
db_session.add(test_user)
db_session.commit()
base_url = "http://localhost:8080" # Adjust this to your API's base URL
headers = {"Authorization": f"Bearer {test_user.id}"}

View File

@@ -1,5 +1,8 @@
import os
ADMIN_USER_NAME = "admin_user"
GUARANTEED_FRESH_SETUP = os.getenv("GUARANTEED_FRESH_SETUP") == "true"
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
@@ -9,3 +12,6 @@ MAX_DELAY = 45
GENERAL_HEADERS = {"Content-Type": "application/json"}
NUM_DOCS = 5
MOCK_CONNECTOR_SERVER_HOST = os.getenv("MOCK_CONNECTOR_SERVER_HOST") or "localhost"
MOCK_CONNECTOR_SERVER_PORT = os.getenv("MOCK_CONNECTOR_SERVER_PORT") or 8001
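These two constants are consumed by the mock-server fixture added further down; in short:
# Sketch mirroring the mock_server_client fixture below.
import httpx
client = httpx.Client(
    base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
    timeout=5.0,
)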

View File

@@ -223,12 +223,13 @@ class CCPairManager:
@staticmethod
def run_once(
cc_pair: DATestCCPair,
from_beginning: bool,
user_performing_action: DATestUser | None = None,
) -> None:
body = {
"connector_id": cc_pair.connector_id,
"credential_ids": [cc_pair.credential_id],
"from_beginning": True,
"from_beginning": from_beginning,
}
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",

View File

@@ -1,9 +1,14 @@
from uuid import uuid4
import requests
from sqlalchemy import and_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
@@ -186,3 +191,39 @@ class DocumentManager:
group_names,
doc_creating_user,
)
@staticmethod
def fetch_documents_for_cc_pair(
cc_pair_id: int,
db_session: Session,
vespa_client: vespa_fixture,
) -> list[SimpleTestDocument]:
stmt = (
select(DocumentByConnectorCredentialPair)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.where(ConnectorCredentialPair.id == cc_pair_id)
)
documents = db_session.execute(stmt).scalars().all()
if not documents:
return []
doc_ids = [document.id for document in documents]
retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]
final_docs: list[SimpleTestDocument] = []
# NOTE: they are really chunks, but we're assuming that for these tests
# we only have one chunk per document for now
for doc_dict in retrieved_docs_dict:
doc_id = doc_dict["fields"]["document_id"]
doc_content = doc_dict["fields"]["content"]
final_docs.append(SimpleTestDocument(id=doc_id, content=doc_content))
return final_docs
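A minimal usage sketch, mirroring the mock-connector tests later in this diff (cc_pair, test_doc, and vespa_client come from those tests' fixtures):
# Sketch: confirm that exactly the expected document landed in Vespa.
with get_session_context_manager() as db_session:
    docs = DocumentManager.fetch_documents_for_cc_pair(
        cc_pair_id=cc_pair.id,
        db_session=db_session,
        vespa_client=vespa_client,
    )
assert [d.id for d in docs] == [test_doc.id]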

View File

@@ -4,6 +4,7 @@ from urllib.parse import urlencode
import requests
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexModelStatus
from onyx.db.models import IndexAttempt
@@ -13,6 +14,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
@@ -92,8 +94,12 @@ class IndexAttemptManager:
"page_size": page_size,
}
url = (
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts"
f"?{urlencode(query_params, doseq=True)}"
)
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
url=url,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
@@ -104,3 +110,125 @@ class IndexAttemptManager:
items=[IndexAttemptSnapshot(**item) for item in data["items"]],
total_items=data["total_items"],
)
@staticmethod
def get_latest_index_attempt_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot | None:
"""Get an IndexAttempt by ID"""
index_attempts = IndexAttemptManager.get_index_attempt_page(
cc_pair_id, user_performing_action=user_performing_action
).items
if not index_attempts:
return None
index_attempts = sorted(
index_attempts, key=lambda x: x.time_started or "0", reverse=True
)
return index_attempts[0]
@staticmethod
def wait_for_index_attempt_start(
cc_pair_id: int,
index_attempts_to_ignore: list[int] | None = None,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
"""Wait for an IndexAttempt to start"""
start = datetime.now()
index_attempts_to_ignore = index_attempts_to_ignore or []
while True:
index_attempt = IndexAttemptManager.get_latest_index_attempt_for_cc_pair(
cc_pair_id=cc_pair_id,
user_performing_action=user_performing_action,
)
if (
index_attempt
and index_attempt.time_started
and index_attempt.id not in index_attempts_to_ignore
):
return index_attempt
elapsed = (datetime.now() - start).total_seconds()
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt for CC Pair {cc_pair_id} did not start within {timeout} seconds"
)
@staticmethod
def get_index_attempt_by_id(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
page_num = 0
page_size = 10
while True:
page = IndexAttemptManager.get_index_attempt_page(
cc_pair_id=cc_pair_id,
page=page_num,
page_size=page_size,
user_performing_action=user_performing_action,
)
for attempt in page.items:
if attempt.id == index_attempt_id:
return attempt
if len(page.items) < page_size:
break
page_num += 1
raise ValueError(f"IndexAttempt {index_attempt_id} not found")
@staticmethod
def wait_for_index_attempt_completion(
index_attempt_id: int,
cc_pair_id: int,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = datetime.now()
while True:
index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair_id,
user_performing_action=user_performing_action,
)
if index_attempt.status and index_attempt.status.is_terminal():
print(f"IndexAttempt {index_attempt_id} completed")
return
elapsed = (datetime.now() - start).total_seconds()
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"
)
print(
f"Waiting for IndexAttempt {index_attempt_id} to complete. "
f"elapsed={elapsed:.2f} timeout={timeout}"
)
@staticmethod
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
include_resolved: bool = True,
user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
if include_resolved:
url += "&include_resolved=true"
response = requests.get(
url=url,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
data = response.json()
return [IndexAttemptErrorPydantic(**item) for item in data["items"]]
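Taken together, the new helpers support the wait-then-verify pattern used by the checkpointing tests further down; a condensed sketch:
# Sketch: wait for the attempt to start, block until it finishes, then check errors.
attempt = IndexAttemptManager.wait_for_index_attempt_start(
    cc_pair_id=cc_pair.id, user_performing_action=admin_user
)
IndexAttemptManager.wait_for_index_attempt_completion(
    index_attempt_id=attempt.id,
    cc_pair_id=cc_pair.id,
    user_performing_action=admin_user,
)
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
    cc_pair_id=cc_pair.id, user_performing_action=admin_user
)
assert not errors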

View File

@@ -1,5 +1,9 @@
import contextlib
import io
import logging
import sys
import time
from logging import Logger
from types import SimpleNamespace
import psycopg2
@@ -11,10 +15,12 @@ from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.engine import SYNC_DB_API
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
@@ -22,6 +28,7 @@ from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import redis_pool
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
@@ -30,8 +37,11 @@ logger = setup_logger()
def _run_migrations(
database_url: str,
database: str,
config_name: str,
postgres_host: str,
postgres_port: str,
redis_port: int,
direction: str = "upgrade",
revision: str = "head",
schema: str = "public",
@@ -45,9 +55,28 @@ def _run_migrations(
alembic_cfg.attributes["configure_logger"] = False
alembic_cfg.config_ini_section = config_name
# Add environment variables to the config attributes
alembic_cfg.attributes["env_vars"] = {
"POSTGRES_HOST": postgres_host,
"POSTGRES_PORT": postgres_port,
"POSTGRES_DB": database,
# some migrations call redis directly, so we need to pass the port
"REDIS_PORT": str(redis_port),
}
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema}"] # type: ignore
# Build the database URL
database_url = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
)
# Set the SQLAlchemy URL in the Alembic configuration
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
@@ -66,86 +95,80 @@ def _run_migrations(
def downgrade_postgres(
database: str = "postgres",
schema: str = "public",
config_name: str = "alembic",
revision: str = "base",
clear_data: bool = False,
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Downgrade Postgres database to base state."""
if clear_data:
if revision != "base":
logger.warning("Clearing data without rolling back to base state")
# Delete all rows to allow migrations to be rolled back
raise ValueError("Clearing data without rolling back to base state")
conn = psycopg2.connect(
dbname=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
host=postgres_host,
port=postgres_port,
)
conn.autocommit = True # Need autocommit for dropping schema
cur = conn.cursor()
# Disable triggers to prevent foreign key constraints from being checked
cur.execute("SET session_replication_role = 'replica';")
# Fetch all table names in the current database
# Close any existing connections to the schema before dropping
cur.execute(
"""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
f"""
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND pg_stat_activity.state = 'idle in transaction'
AND pid <> pg_backend_pid();
"""
)
tables = cur.fetchall()
# Drop and recreate the public schema - this removes ALL objects
cur.execute(f"DROP SCHEMA {schema} CASCADE;")
cur.execute(f"CREATE SCHEMA {schema};")
for table in tables:
table_name = table[0]
# Restore default privileges
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;")
cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;")
# Don't touch migration history or Kombu
if table_name in ("alembic_version", "kombu_message", "kombu_queue"):
continue
cur.execute(f'DELETE FROM "{table_name}"')
# Re-enable triggers
cur.execute("SET session_replication_role = 'origin';")
conn.commit()
cur.close()
conn.close()
return
# Downgrade to base
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
database=database,
config_name=config_name,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
direction="downgrade",
revision=revision,
)
def upgrade_postgres(
database: str = "postgres", config_name: str = "alembic", revision: str = "head"
database: str = "postgres",
config_name: str = "alembic",
revision: str = "head",
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Upgrade Postgres database to latest version."""
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
database=database,
config_name=config_name,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
direction="upgrade",
revision=revision,
)
@@ -155,20 +178,44 @@ def reset_postgres(
database: str = "postgres",
config_name: str = "alembic",
setup_onyx: bool = True,
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Reset the Postgres database."""
downgrade_postgres(
database=database, config_name=config_name, revision="base", clear_data=True
database=database,
config_name=config_name,
revision="base",
clear_data=True,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
upgrade_postgres(
database=database,
config_name=config_name,
revision="head",
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
upgrade_postgres(database=database, config_name=config_name, revision="head")
if setup_onyx:
logger.info("Setting up Postgres...")
with get_session_context_manager() as db_session:
setup_postgres(db_session)
def reset_vespa() -> None:
"""Wipe all data from the Vespa index."""
def reset_vespa(
skip_creating_indices: bool, document_id_endpoint: str = DOCUMENT_ID_ENDPOINT
) -> None:
"""Wipe all data from the Vespa index.
Args:
skip_creating_indices: If True, the indices will not be recreated.
This is useful if the indices already exist and you do not want to
recreate them (e.g. when running parallel tests).
"""
with get_session_context_manager() as db_session:
# swap to the correct default model
check_index_swap(db_session)
@@ -177,18 +224,21 @@ def reset_vespa() -> None:
multipass_config = get_multipass_config(search_settings)
index_name = search_settings.index_name
success = setup_vespa(
document_index=VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=multipass_config.enable_large_chunks,
secondary_large_chunks_enabled=None,
),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError("Could not connect to Vespa within the specified timeout.")
if not skip_creating_indices:
success = setup_vespa(
document_index=VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=multipass_config.enable_large_chunks,
secondary_large_chunks_enabled=None,
),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError(
"Could not connect to Vespa within the specified timeout."
)
for _ in range(5):
try:
@@ -199,7 +249,7 @@ def reset_vespa() -> None:
if continuation:
params = {**params, "continuation": continuation}
response = requests.delete(
DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
document_id_endpoint.format(index_name=index_name), params=params
)
response.raise_for_status()
@@ -313,11 +363,99 @@ def reset_vespa_multitenant() -> None:
time.sleep(5)
def reset_all() -> None:
def reset_all(
database: str = "postgres",
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
silence_logs: bool = False,
skip_creating_indices: bool = False,
document_id_endpoint: str = DOCUMENT_ID_ENDPOINT,
) -> None:
if not silence_logs:
with contextlib.redirect_stdout(sys.stdout), contextlib.redirect_stderr(
sys.stderr
):
_do_reset(
database,
postgres_host,
postgres_port,
redis_port,
skip_creating_indices,
document_id_endpoint,
)
return
# Store original logging levels
loggers_to_silence: list[Logger] = [
logging.getLogger(), # Root logger
logging.getLogger("alembic"),
logger.logger, # Our custom logger
]
original_levels = [logger.level for logger in loggers_to_silence]
# Temporarily set all loggers to ERROR level
for log in loggers_to_silence:
log.setLevel(logging.ERROR)
stdout_redirect = io.StringIO()
stderr_redirect = io.StringIO()
try:
with contextlib.redirect_stdout(stdout_redirect), contextlib.redirect_stderr(
stderr_redirect
):
_do_reset(
database,
postgres_host,
postgres_port,
redis_port,
skip_creating_indices,
document_id_endpoint,
)
except Exception as e:
print(stdout_redirect.getvalue(), file=sys.stdout)
print(stderr_redirect.getvalue(), file=sys.stderr)
raise e
finally:
# Restore original logging levels
for logger_, level in zip(loggers_to_silence, original_levels):
logger_.setLevel(level)
def _do_reset(
database: str,
postgres_host: str,
postgres_port: str,
redis_port: int,
skip_creating_indices: bool,
document_id_endpoint: str,
) -> None:
"""NOTE: should only be be running in one worker/thread a time."""
# force re-create the engine to allow for the same worker to reset
# different databases
with SqlEngine._lock:
SqlEngine._engine = SqlEngine._init_engine(
host=postgres_host,
port=postgres_port,
db=database,
)
# same with redis
redis_pool._init_pools(redis_port=redis_port)
logger.info("Resetting Postgres...")
reset_postgres()
reset_postgres(
database=database,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
logger.info("Resetting Vespa...")
reset_vespa()
reset_vespa(
skip_creating_indices=skip_creating_indices,
document_id_endpoint=document_id_endpoint,
)
def reset_all_multitenant() -> None:
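For context, the parallel runner added later in this diff invokes the extended reset_all between tests roughly like this (the config objects come from that runner, which also passes a per-port document_id_endpoint):
# Sketch based on run_tests_parallel.py below: reset one instance between tests.
reset_all(
    database=deployment_config.postgres_db,
    postgres_port=str(shared_services_config.postgres_port),
    redis_port=deployment_config.redis_port,
    silence_logs=True,
    skip_creating_indices=True,  # indices were created once during kickoff
)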

View File

@@ -0,0 +1,57 @@
import uuid
from datetime import datetime
from datetime import timezone
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import Section
def create_test_document(
doc_id: str | None = None,
text: str = "Test content",
link: str = "http://test.com",
source: DocumentSource = DocumentSource.MOCK_CONNECTOR,
metadata: dict | None = None,
) -> Document:
"""Create a test document with the given parameters.
Args:
doc_id: Optional document ID. If not provided, a random UUID will be generated.
text: The text content of the document. Defaults to "Test content".
link: The link for the document section. Defaults to "http://test.com".
source: The document source. Defaults to MOCK_CONNECTOR.
metadata: Optional metadata dictionary. Defaults to empty dict.
"""
doc_id = doc_id or f"test-doc-{uuid.uuid4()}"
return Document(
id=doc_id,
sections=[Section(text=text, link=link)],
source=source,
semantic_identifier=doc_id,
doc_updated_at=datetime.now(timezone.utc),
metadata=metadata or {},
)
def create_test_document_failure(
doc_id: str,
failure_message: str = "Simulated failure",
document_link: str | None = None,
) -> ConnectorFailure:
"""Create a test document failure with the given parameters.
Args:
doc_id: The ID of the document that failed.
failure_message: The failure message. Defaults to "Simulated failure".
document_link: Optional link to the failed document.
"""
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=document_link,
),
failure_message=failure_message,
)
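A quick illustration of how these factories pair up in the failure tests below:
# Sketch: one successfully indexed document plus a simulated failure for another.
doc_ok = create_test_document(text="hello world")
doc_bad = create_test_document()
failure = create_test_document_failure(doc_id=doc_bad.id)
assert failure.failed_document.document_id == doc_bad.id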

View File

@@ -0,0 +1,18 @@
import multiprocessing
from collections.abc import Callable
from typing import Any
from typing import TypeVar
T = TypeVar("T")
def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
# Use multiprocessing to prevent a thread from blocking the main thread
with multiprocessing.Pool(processes=1) as pool:
async_result = pool.apply_async(task, kwds=kwargs)
try:
# Wait at most timeout seconds for the function to complete
result = async_result.get(timeout=timeout)
return result
except multiprocessing.TimeoutError:
raise TimeoutError(f"Function timed out after {timeout} seconds")
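A minimal sketch of exercising the helper (the slow task is purely illustrative; keep it at module level so the worker process can pickle it):
# Sketch: a task that exceeds the timeout surfaces as TimeoutError in the caller.
import time
def _slow(seconds: float) -> str:
    time.sleep(seconds)
    return "done"
try:
    run_with_timeout(_slow, timeout=1, kwargs={"seconds": 5})
except TimeoutError as exc:
    print(exc)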

View File

@@ -1,13 +1,13 @@
import os
from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import GUARANTEED_FRESH_SETUP
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
@@ -15,61 +15,66 @@ from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
from tests.integration.introspection import load_env_vars
# Load environment variables at the module level
load_env_vars()
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))
"""NOTE: for some reason using this seems to lead to misc
`sqlalchemy.exc.OperationalError: (psycopg2.OperationalError) server closed the connection unexpectedly`
errors.
Commenting out till we can get to the bottom of it. For now, just using
instantiate the session directly within the test.
"""
# @pytest.fixture
# def db_session() -> Generator[Session, None, None]:
# with get_session_context_manager() as session:
# yield session
@pytest.fixture
def db_session() -> Generator[Session, None, None]:
with get_session_context_manager() as session:
yield session
@pytest.fixture
def vespa_client(db_session: Session) -> vespa_fixture:
search_settings = get_current_search_settings(db_session)
return vespa_fixture(index_name=search_settings.index_name)
def vespa_client() -> vespa_fixture:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
return vespa_fixture(index_name=search_settings.index_name)
@pytest.fixture
def reset() -> None:
if GUARANTEED_FRESH_SETUP:
print("GUARANTEED_FRESH_SETUP is true, skipping reset")
return None
reset_all()
@pytest.fixture
def new_admin_user(reset: None) -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
return UserManager.create(name=ADMIN_USER_NAME)
except Exception:
return None
@pytest.fixture
def admin_user() -> DATestUser | None:
def admin_user() -> DATestUser:
try:
return UserManager.create(name="admin_user")
except Exception:
pass
user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)
# if there are other users for some reason, reset and try again
if not UserManager.is_role(user, UserRole.ADMIN):
print("Trying to reset")
reset_all()
user = UserManager.create(name=ADMIN_USER_NAME)
return user
except Exception as e:
print(f"Failed to create admin user: {e}")
try:
return UserManager.login_as_user(
user = UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
@@ -79,10 +84,16 @@ def admin_user() -> DATestUser | None:
is_active=True,
)
)
except Exception:
pass
if not UserManager.is_role(user, UserRole.ADMIN):
reset_all()
user = UserManager.create(name=ADMIN_USER_NAME)
return user
return None
return user
except Exception as e:
print(f"Failed to create or login as admin user: {e}")
raise RuntimeError("Failed to create or login as admin user")
@pytest.fixture

View File

@@ -118,6 +118,7 @@ def test_google_permission_sync(
GoogleDriveService, str, DATestCCPair, DATestUser, DATestUser, DATestUser
],
) -> None:
print("Running test_google_permission_sync")
(
drive_service,
drive_id,
@@ -138,7 +139,9 @@ def test_google_permission_sync(
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1)
# run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair, after=before, user_performing_action=admin_user
)
@@ -184,7 +187,9 @@ def test_google_permission_sync(
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,

View File

@@ -113,7 +113,9 @@ def test_slack_permission_sync(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
@@ -305,7 +307,9 @@ def test_slack_group_permission_sync(
)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,

View File

@@ -111,7 +111,9 @@ def test_slack_prune(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,

View File

@@ -0,0 +1,71 @@
import os
from pathlib import Path
import pytest
from _pytest.nodes import Item
def list_all_tests(directory: str | Path = ".") -> list[str]:
"""
List all pytest test functions under the specified directory.
Args:
directory: Directory path to search for tests (defaults to current directory)
Returns:
List of test function names with their module paths
"""
directory = Path(directory).absolute()
print(f"Searching for tests in: {directory}")
class TestCollector:
def __init__(self) -> None:
self.collected: list[str] = []
def pytest_collection_modifyitems(self, items: list[Item]) -> None:
for item in items:
if isinstance(item, Item):
# Get the relative path from the test file to the directory we're searching from
rel_path = Path(item.fspath).relative_to(directory)
# Remove the .py extension
module_path = str(rel_path.with_suffix(""))
# Replace directory separators with dots
module_path = module_path.replace("/", ".")
test_name = item.name
self.collected.append(f"{module_path}::{test_name}")
collector = TestCollector()
# Run pytest in collection-only mode
pytest.main(
[
str(directory),
"--collect-only",
"-q", # quiet mode
],
plugins=[collector],
)
return sorted(collector.collected)
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
if __name__ == "__main__":
tests = list_all_tests()
print("\nFound tests:")
for test in tests:
print(f"- {test}")

View File

@@ -0,0 +1,637 @@
#!/usr/bin/env python3
import atexit
import os
import random
import signal
import socket
import subprocess
import sys
import time
import uuid
from collections.abc import Callable
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import NamedTuple
import requests
import yaml
from onyx.configs.constants import AuthType
from onyx.document_index.vespa.index import VespaIndex
BACKEND_DIR_PATH = Path(__file__).parent.parent.parent
COMPOSE_DIR_PATH = BACKEND_DIR_PATH.parent / "deployment/docker_compose"
DEFAULT_EMBEDDING_DIMENSION = 768
DEFAULT_SCHEMA_NAME = "danswer_chunk_nomic_ai_nomic_embed_text_v1"
class DeploymentConfig(NamedTuple):
instance_num: int
api_port: int
web_port: int
nginx_port: int
redis_port: int
postgres_db: str
class SharedServicesConfig(NamedTuple):
run_id: uuid.UUID
postgres_port: int
vespa_port: int
vespa_tenant_port: int
model_server_port: int
def get_random_port() -> int:
"""Find a random available port."""
while True:
port = random.randint(10000, 65535)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex(("localhost", port)) != 0:
return port
def cleanup_pid(pid: int) -> None:
"""Cleanup a specific PID."""
print(f"Killing process {pid}")
try:
os.kill(pid, signal.SIGTERM)
except ProcessLookupError:
print(f"Process {pid} not found")
def get_shared_services_stack_name(run_id: uuid.UUID) -> str:
return f"base-onyx-{run_id}"
def get_db_name(instance_num: int) -> str:
"""Get the database name for a given instance number."""
return f"onyx_{instance_num}"
def get_vector_db_prefix(instance_num: int) -> str:
"""Get the vector DB prefix for a given instance number."""
return f"test_instance_{instance_num}"
def setup_db(
instance_num: int,
postgres_port: int,
) -> str:
env = os.environ.copy()
# Wait for postgres to be ready
max_attempts = 10
for attempt in range(max_attempts):
try:
subprocess.run(
[
"psql",
"-h",
"localhost",
"-p",
str(postgres_port),
"-U",
"postgres",
"-c",
"SELECT 1",
],
env={**env, "PGPASSWORD": "password"},
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
break
except subprocess.CalledProcessError:
if attempt == max_attempts - 1:
raise RuntimeError("Postgres failed to become ready within timeout")
time.sleep(1)
db_name = get_db_name(instance_num)
# Create the database first
subprocess.run(
[
"psql",
"-h",
"localhost",
"-p",
str(postgres_port),
"-U",
"postgres",
"-c",
f"CREATE DATABASE {db_name}",
],
env={**env, "PGPASSWORD": "password"},
check=True,
)
# NEW: Stamp this brand-new DB at 'base' so Alembic doesn't fail
subprocess.run(
[
"alembic",
"stamp",
"base",
],
env={
**env,
"PGPASSWORD": "password",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
},
check=True,
cwd=str(BACKEND_DIR_PATH),
)
# Run alembic upgrade to create tables
max_attempts = 3
for attempt in range(max_attempts):
try:
subprocess.run(
["alembic", "upgrade", "head"],
env={
**env,
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
},
check=True,
cwd=str(BACKEND_DIR_PATH),
)
break
except subprocess.CalledProcessError:
if attempt == max_attempts - 1:
raise
print("Alembic upgrade failed, retrying in 5 seconds...")
time.sleep(5)
return db_name
def start_api_server(
instance_num: int,
model_server_port: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
redis_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> int:
"""Start the API server.
NOTE: assumes that Postgres is all set up (database exists, migrations ran)
"""
print("Starting API server...")
db_name = get_db_name(instance_num)
vector_db_prefix = get_vector_db_prefix(instance_num)
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
"REDIS_HOST": "localhost",
"REDIS_PORT": str(redis_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"MODEL_SERVER_PORT": str(model_server_port),
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": vector_db_prefix,
"LOG_LEVEL": "debug",
"AUTH_TYPE": AuthType.BASIC,
}
)
port = get_random_port()
# Open log file for API server in /tmp
log_file = open(f"/tmp/api_server_{instance_num}.txt", "w")
process = subprocess.Popen(
[
"uvicorn",
"onyx.main:app",
"--host",
"localhost",
"--port",
str(port),
],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
register_process(process)
return port
def start_background(
instance_num: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
redis_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> None:
"""Start the background process."""
print("Starting background process...")
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": get_db_name(instance_num),
"REDIS_HOST": "localhost",
"REDIS_PORT": str(redis_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": get_vector_db_prefix(
instance_num
),
"LOG_LEVEL": "debug",
}
)
str(Path(__file__).parent / "backend")
# Open log file for background process in /tmp
log_file = open(f"/tmp/background_{instance_num}.txt", "w")
process = subprocess.Popen(
["supervisord", "-n", "-c", "./supervisord.conf"],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
register_process(process)
def start_shared_services(run_id: uuid.UUID) -> SharedServicesConfig:
"""Start Postgres and Vespa using docker-compose.
Returns (postgres_port, vespa_port, vespa_tenant_port, model_server_port)
"""
print("Starting database services...")
postgres_port = get_random_port()
vespa_port = get_random_port()
vespa_tenant_port = get_random_port()
model_server_port = get_random_port()
minimal_compose = {
"services": {
"relational_db": {
"image": "postgres:15.2-alpine",
"command": "-c 'max_connections=1000'",
"environment": {
"POSTGRES_USER": os.getenv("POSTGRES_USER", "postgres"),
"POSTGRES_PASSWORD": os.getenv("POSTGRES_PASSWORD", "password"),
},
"ports": [f"{postgres_port}:5432"],
},
"index": {
"image": "vespaengine/vespa:8.277.17",
"ports": [
f"{vespa_port}:8081", # Main Vespa port
f"{vespa_tenant_port}:19071", # Tenant port
],
},
},
}
# Write the minimal compose file
temp_compose = Path("/tmp/docker-compose.minimal.yml")
with open(temp_compose, "w") as f:
yaml.dump(minimal_compose, f)
# Start the services
subprocess.run(
[
"docker",
"compose",
"-f",
str(temp_compose),
"-p",
get_shared_services_stack_name(run_id),
"up",
"-d",
],
check=True,
)
# Start the shared model server
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"LOG_LEVEL": "debug",
}
)
# Open log file for shared model server in /tmp
log_file = open("/tmp/shared_model_server.txt", "w")
process = subprocess.Popen(
[
"uvicorn",
"model_server.main:app",
"--host",
"0.0.0.0",
"--port",
str(model_server_port),
],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
atexit.register(cleanup_pid, process.pid)
shared_services_config = SharedServicesConfig(
run_id, postgres_port, vespa_port, vespa_tenant_port, model_server_port
)
print(f"Shared services config: {shared_services_config}")
return shared_services_config
def prepare_vespa(instance_ids: list[int], vespa_tenant_port: int) -> None:
schema_names = [
(
f"{get_vector_db_prefix(instance_id)}_{DEFAULT_SCHEMA_NAME}",
DEFAULT_EMBEDDING_DIMENSION,
False,
)
for instance_id in instance_ids
]
print(f"Creating indices: {schema_names}")
for _ in range(7):
try:
VespaIndex.create_indices(
schema_names, f"http://localhost:{vespa_tenant_port}/application/v2"
)
return
except Exception as e:
print(f"Error creating indices: {e}. Trying again in 5 seconds...")
time.sleep(5)
raise RuntimeError("Failed to create indices in Vespa")
def start_redis(
instance_num: int,
register_process: Callable[[subprocess.Popen], None],
) -> int:
"""Start a Redis instance for a specific deployment."""
print(f"Starting Redis for instance {instance_num}...")
redis_port = get_random_port()
container_name = f"redis-onyx-{instance_num}"
# Start Redis using docker run
subprocess.run(
[
"docker",
"run",
"-d",
"--name",
container_name,
"-p",
f"{redis_port}:6379",
"redis:7.4-alpine",
"redis-server",
"--save",
'""',
"--appendonly",
"no",
],
check=True,
)
return redis_port
def launch_instance(
instance_num: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
model_server_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> DeploymentConfig:
"""Launch a Docker Compose instance with custom ports."""
api_port = get_random_port()
web_port = get_random_port()
nginx_port = get_random_port()
# Start Redis for this instance
redis_port = start_redis(instance_num, register_process)
try:
db_name = setup_db(instance_num, postgres_port)
api_port = start_api_server(
instance_num,
model_server_port, # Use shared model server port
postgres_port,
vespa_port,
vespa_tenant_port,
redis_port,
register_process,
)
start_background(
instance_num,
postgres_port,
vespa_port,
vespa_tenant_port,
redis_port,
register_process,
)
except Exception as e:
print(f"Failed to start API server for instance {instance_num}: {e}")
raise
return DeploymentConfig(
instance_num, api_port, web_port, nginx_port, redis_port, db_name
)
def wait_for_instance(
ports: DeploymentConfig, max_attempts: int = 120, wait_seconds: int = 2
) -> None:
"""Wait for an instance to be healthy."""
print(f"Waiting for instance {ports.instance_num} to be ready...")
for attempt in range(1, max_attempts + 1):
try:
response = requests.get(f"http://localhost:{ports.api_port}/health")
if response.status_code == 200:
print(
f"Instance {ports.instance_num} is ready on port {ports.api_port}"
)
return
raise ConnectionError(
f"Health check returned status {response.status_code}"
)
except (requests.RequestException, ConnectionError):
if attempt == max_attempts:
raise TimeoutError(
f"Timeout waiting for instance {ports.instance_num} "
f"on port {ports.api_port}"
)
print(
f"Waiting for instance {ports.instance_num} on port "
f" {ports.api_port}... ({attempt}/{max_attempts})"
)
time.sleep(wait_seconds)
def cleanup_instance(instance_num: int) -> None:
"""Cleanup a specific instance."""
print(f"Cleaning up instance {instance_num}...")
temp_compose = Path(f"/tmp/docker-compose.dev.instance{instance_num}.yml")
try:
subprocess.run(
[
"docker",
"compose",
"-f",
str(temp_compose),
"-p",
f"onyx-stack-{instance_num}",
"down",
],
check=True,
)
print(f"Instance {instance_num} cleaned up successfully")
except subprocess.CalledProcessError:
print(f"Error cleaning up instance {instance_num}")
except FileNotFoundError:
print(f"No compose file found for instance {instance_num}")
finally:
# Clean up the temporary compose file if it exists
if temp_compose.exists():
temp_compose.unlink()
print(f"Removed temporary compose file for instance {instance_num}")
def run_x_instances(
num_instances: int,
) -> tuple[SharedServicesConfig, list[DeploymentConfig]]:
"""Start x instances of the application and return their configurations."""
run_id = uuid.uuid4()
instance_ids = list(range(1, num_instances + 1))
_pids: list[int] = []
def register_process(process: subprocess.Popen) -> None:
_pids.append(process.pid)
def cleanup_all_instances() -> None:
"""Cleanup all instances."""
print("Cleaning up all instances...")
# Stop the database services
subprocess.run(
[
"docker",
"compose",
"-p",
get_shared_services_stack_name(run_id),
"-f",
"/tmp/docker-compose.minimal.yml",
"down",
],
check=True,
)
# Stop and remove all Redis containers
for instance_id in range(1, num_instances + 1):
container_name = f"redis-onyx-{instance_id}"
try:
subprocess.run(["docker", "rm", "-f", container_name], check=True)
except subprocess.CalledProcessError:
print(f"Error cleaning up Redis container {container_name}")
for pid in _pids:
cleanup_pid(pid)
# Register cleanup handler
atexit.register(cleanup_all_instances)
# Start database services first
print("Starting shared services...")
shared_services_config = start_shared_services(run_id)
# create documents
print("Creating indices in Vespa...")
prepare_vespa(instance_ids, shared_services_config.vespa_tenant_port)
# Use ThreadPool to launch instances in parallel and collect results
# NOTE: only kick off 10 at a time to avoid overwhelming the system
print("Launching instances...")
with ThreadPool(processes=len(instance_ids)) as pool:
# Create list of arguments for each instance
launch_args = [
(
i,
shared_services_config.postgres_port,
shared_services_config.vespa_port,
shared_services_config.vespa_tenant_port,
shared_services_config.model_server_port,
register_process,
)
for i in instance_ids
]
# Launch instances and get results
port_configs = pool.starmap(launch_instance, launch_args)
# Wait for all instances to be healthy
print("Waiting for instances to be healthy...")
with ThreadPool(processes=len(port_configs)) as pool:
pool.map(wait_for_instance, port_configs)
print("All instances launched!")
print("Database Services:")
print(f"Postgres port: {shared_services_config.postgres_port}")
print(f"Vespa main port: {shared_services_config.vespa_port}")
print(f"Vespa tenant port: {shared_services_config.vespa_tenant_port}")
print("\nApplication Instances:")
for ports in port_configs:
print(
f"Instance {ports.instance_num}: "
f"API={ports.api_port}, Web={ports.web_port}, Nginx={ports.nginx_port}"
)
return shared_services_config, port_configs
def main() -> None:
shared_services_config, port_configs = run_x_instances(1)
# Run pytest with the API server port set
api_port = port_configs[0].api_port # Use first instance's API port
try:
subprocess.run(
["pytest", "tests/integration/openai_assistants_api"],
env={**os.environ, "API_SERVER_PORT": str(api_port)},
cwd=str(BACKEND_DIR_PATH),
check=True,
)
except subprocess.CalledProcessError as e:
print(f"Tests failed with exit code {e.returncode}")
sys.exit(e.returncode)
time.sleep(5)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,20 @@
version: '3.8'
services:
mock_connector_server:
build:
context: ./mock_connector_server
dockerfile: Dockerfile
ports:
- "8001:8001"
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8001/health"]
interval: 10s
timeout: 5s
retries: 5
networks:
- onyx-stack_default
networks:
onyx-stack_default:
name: onyx-stack_default
external: true

View File

@@ -0,0 +1,9 @@
FROM python:3.11.7-slim-bookworm
WORKDIR /app
RUN pip install fastapi uvicorn
COPY ./main.py /app/main.py
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8001"]

View File

@@ -0,0 +1,76 @@
from fastapi import FastAPI
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
# We would like to import these, but it makes building this so much harder/slower
# from onyx.connectors.mock_connector.connector import SingleConnectorYield
# from onyx.connectors.models import ConnectorCheckpoint
app = FastAPI()
# Global state to store connector behavior configuration
class ConnectorBehavior(BaseModel):
connector_yields: list[dict] = Field(
default_factory=list
) # really list[SingleConnectorYield]
called_with_checkpoints: list[dict] = Field(
default_factory=list
) # really list[ConnectorCheckpoint]
current_behavior: ConnectorBehavior = ConnectorBehavior()
@app.post("/set-behavior")
async def set_behavior(behavior: list[dict]) -> None:
"""Set the behavior for the next connector run"""
global current_behavior
current_behavior = ConnectorBehavior(connector_yields=behavior)
@app.get("/get-documents")
async def get_documents() -> list[dict]:
"""Get the next batch of documents and update the checkpoint"""
global current_behavior
if not current_behavior.connector_yields:
raise HTTPException(
status_code=400, detail="No documents or failures configured"
)
connector_yields = current_behavior.connector_yields
# Clear the current behavior after returning it
current_behavior = ConnectorBehavior()
return connector_yields
@app.post("/add-checkpoint")
async def add_checkpoint(checkpoint: dict) -> None:
"""Add a checkpoint to the list of checkpoints. Called by the MockConnector."""
global current_behavior
current_behavior.called_with_checkpoints.append(checkpoint)
@app.get("/get-checkpoints")
async def get_checkpoints() -> list[dict]:
"""Get the list of checkpoints. Used by the test to verify the
proper checkpoint ordering."""
global current_behavior
return current_behavior.called_with_checkpoints
@app.post("/reset")
async def reset() -> None:
"""Reset the connector behavior to default"""
global current_behavior
current_behavior = ConnectorBehavior()
@app.get("/health")
async def health_check() -> dict[str, str]:
"""Health check endpoint"""
return {"status": "healthy"}

View File

@@ -0,0 +1,266 @@
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from pathlib import Path
from tests.integration.common_utils.reset import reset_all
from tests.integration.introspection import list_all_tests
from tests.integration.introspection import load_env_vars
from tests.integration.kickoff import BACKEND_DIR_PATH
from tests.integration.kickoff import DeploymentConfig
from tests.integration.kickoff import run_x_instances
from tests.integration.kickoff import SharedServicesConfig
@dataclass
class TestResult:
test_name: str
success: bool
output: str
error: str | None = None
def run_single_test(
test_name: str,
deployment_config: DeploymentConfig,
shared_services_config: SharedServicesConfig,
result_queue: multiprocessing.Queue,
) -> None:
"""Run a single test with the given API port."""
test_path, test_name = test_name.split("::")
processed_test_name = f"{test_path.replace('.', '/')}.py::{test_name}"
print(f"Running test: {processed_test_name}", flush=True)
try:
env = {
**os.environ,
"API_SERVER_PORT": str(deployment_config.api_port),
"PYTHONPATH": ".",
"GUARANTEED_FRESH_SETUP": "true",
"POSTGRES_PORT": str(shared_services_config.postgres_port),
"POSTGRES_DB": deployment_config.postgres_db,
"REDIS_PORT": str(deployment_config.redis_port),
"VESPA_PORT": str(shared_services_config.vespa_port),
"VESPA_TENANT_PORT": str(shared_services_config.vespa_tenant_port),
}
result = subprocess.run(
["pytest", processed_test_name, "-v"],
env=env,
cwd=str(BACKEND_DIR_PATH),
capture_output=True,
text=True,
)
result_queue.put(
TestResult(
test_name=test_name,
success=result.returncode == 0,
output=result.stdout,
error=result.stderr if result.returncode != 0 else None,
)
)
except Exception as e:
result_queue.put(
TestResult(
test_name=test_name,
success=False,
output="",
error=str(e),
)
)
def worker(
test_queue: queue.Queue[str],
instance_queue: queue.Queue[int],
result_queue: multiprocessing.Queue,
shared_services_config: SharedServicesConfig,
deployment_configs: list[DeploymentConfig],
reset_lock: LockType,
) -> None:
"""Worker process that runs tests on available instances."""
while True:
# Get the next test from the queue
try:
test = test_queue.get(block=False)
except queue.Empty:
break
# Get an available instance
instance_idx = instance_queue.get()
deployment_config = deployment_configs[
instance_idx - 1
] # Convert to 0-based index
try:
# Run the test
run_single_test(
test, deployment_config, shared_services_config, result_queue
)
# get instance ready for next test
print(
f"Resetting instance for next. DB: {deployment_config.postgres_db}, "
f"Port: {shared_services_config.postgres_port}"
)
# alembic is NOT thread-safe, so we need to make sure only one worker is resetting at a time
with reset_lock:
reset_all(
database=deployment_config.postgres_db,
postgres_port=str(shared_services_config.postgres_port),
redis_port=deployment_config.redis_port,
silence_logs=True,
# indices are created during the kickoff process, no need to recreate them
skip_creating_indices=True,
# use the special vespa port
document_id_endpoint=(
f"http://localhost:{shared_services_config.vespa_port}"
"/document/v1/default/{{index_name}}/docid"
),
)
except Exception as e:
# Log the error and put it in the result queue
error_msg = f"Critical error in worker thread for test {test}: {str(e)}"
print(error_msg, file=sys.stderr)
result_queue.put(
TestResult(
test_name=test,
success=False,
output="",
error=error_msg,
)
)
# Re-raise to stop the worker
raise
finally:
# Put the instance back in the queue
instance_queue.put(instance_idx)
test_queue.task_done()
def main() -> None:
NUM_INSTANCES = 7
# Get all tests
prefixes = ["tests", "connector_job_tests"]
tests = []
for prefix in prefixes:
tests += [
f"tests/integration/{prefix}/{test_path}"
for test_path in list_all_tests(Path(__file__).parent / prefix)
]
print(f"Found {len(tests)} tests to run")
# load env vars which will be passed into the tests
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))
# For debugging
# tests = [test for test in tests if "openai_assistants_api" in test]
# tests = tests[:2]
print(f"Running {len(tests)} tests")
# Start all instances at once
shared_services_config, deployment_configs = run_x_instances(NUM_INSTANCES)
# Create queues and lock
test_queue: queue.Queue[str] = queue.Queue()
instance_queue: queue.Queue[int] = queue.Queue()
result_queue: multiprocessing.Queue = multiprocessing.Queue()
reset_lock: LockType = multiprocessing.Lock()
# Fill the instance queue with available instance numbers
for i in range(1, NUM_INSTANCES + 1):
instance_queue.put(i)
# Fill the test queue with all tests
for test in tests:
test_queue.put(test)
# Start worker threads
workers = []
for _ in range(NUM_INSTANCES):
worker_thread = threading.Thread(
target=worker,
args=(
test_queue,
instance_queue,
result_queue,
shared_services_config,
deployment_configs,
reset_lock,
),
)
worker_thread.start()
workers.append(worker_thread)
# Monitor workers and fail fast if any die
try:
while any(w.is_alive() for w in workers):
# Check if all tests are done
if test_queue.empty() and all(not w.is_alive() for w in workers):
break
# Check for dead workers that died with unfinished tests
if not test_queue.empty() and any(not w.is_alive() for w in workers):
print(
"\nCritical: Worker thread(s) died with tests remaining!",
file=sys.stderr,
)
sys.exit(1)
time.sleep(0.1) # Avoid busy waiting
# Collect results
print("Collecting results")
results: list[TestResult] = []
while not result_queue.empty():
results.append(result_queue.get())
# Print results
print("\nTest Results:")
failed = False
failed_tests: list[str] = []
total_tests = len(results)
passed_tests = 0
for result in results:
status = "✅ PASSED" if result.success else "❌ FAILED"
print(f"{status} - {result.test_name}")
if result.success:
passed_tests += 1
else:
failed = True
failed_tests.append(result.test_name)
print("Error output:")
print(result.error)
print("Test output:")
print(result.output)
print("-" * 80)
# Print summary
print("\nTest Summary:")
print(f"Total Tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {len(failed_tests)}")
if failed_tests:
print("\nFailed Tests:")
for test_name in failed_tests:
print(f"{test_name}")
print()
if failed:
sys.exit(1)
except KeyboardInterrupt:
print("\nTest run interrupted by user", file=sys.stderr)
sys.exit(130) # Standard exit code for SIGINT
except Exception as e:
print(f"\nCritical error during result collection: {str(e)}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -9,6 +9,8 @@ from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocumentFailure
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import create_index_attempt
@@ -101,10 +103,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
create_index_attempt_error(
index_attempt_id=new_attempt.id,
batch=1,
docs=[],
exception_msg="",
exception_traceback="",
connector_credential_pair_id=cc_pair_1.id,
failure=ConnectorFailure(
failure_message="Test error",
failed_document=DocumentFailure(
document_id=cc_pair_1.documents[0].id,
document_link=None,
),
failed_entity=None,
),
db_session=db_session,
)
@@ -127,10 +134,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
)
create_index_attempt_error(
index_attempt_id=attempt_id,
batch=1,
docs=[],
exception_msg="",
exception_traceback="",
connector_credential_pair_id=cc_pair_1.id,
failure=ConnectorFailure(
failure_message="Test error",
failed_document=DocumentFailure(
document_id=cc_pair_1.documents[0].id,
document_link=None,
),
failed_entity=None,
),
db_session=db_session,
)

View File

@@ -0,0 +1,518 @@
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import httpx
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import InputType
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexingStatus
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_HOST
from tests.integration.common_utils.constants import MOCK_CONNECTOR_SERVER_PORT
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.test_document_utils import create_test_document
from tests.integration.common_utils.test_document_utils import (
create_test_document_failure,
)
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.fixture
def mock_server_client() -> httpx.Client:
print(
f"Initializing mock server client with host: "
f"{MOCK_CONNECTOR_SERVER_HOST} and port: "
f"{MOCK_CONNECTOR_SERVER_PORT}"
)
return httpx.Client(
base_url=f"http://{MOCK_CONNECTOR_SERVER_HOST}:{MOCK_CONNECTOR_SERVER_PORT}",
timeout=5.0,
)
def test_mock_connector_basic_flow(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that the mock connector can successfully process documents and failures"""
# Set up mock server behavior
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"failures": [],
}
],
)
assert response.status_code == 200
# create CC Pair + index attempt
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
user_performing_action=admin_user,
)
# wait for index attempt to start
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# wait for index attempt to finish
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# validate status
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_index_attempt.status == IndexingStatus.SUCCESS
# Verify results
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == test_doc.id
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 0
def test_mock_connector_with_failures(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that the mock connector processes both successes and failures properly."""
doc1 = create_test_document()
doc2 = create_test_document()
doc2_failure = create_test_document_failure(doc_id=doc2.id)
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [doc1.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"failures": [doc2_failure.model_dump(mode="json")],
}
],
)
assert response.status_code == 200
# Create a CC Pair for the mock connector
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-failure-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
user_performing_action=admin_user,
)
# Wait for the index attempt to start and then complete
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# validate status
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
# Verify results: doc1 should be indexed and doc2 should have an error entry
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == doc1.id
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 1
error = errors[0]
assert error.failure_message == doc2_failure.failure_message
assert error.document_id == doc2.id
def test_mock_connector_failure_recovery(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that a failed document can be successfully indexed in a subsequent attempt
while maintaining previously successful documents."""
# Create test documents and failure
doc1 = create_test_document()
doc2 = create_test_document()
doc2_failure = create_test_document_failure(doc_id=doc2.id)
entity_id = "test-entity-id"
entity_failure_msg = "Simulated unhandled error"
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [doc1.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"failures": [
doc2_failure.model_dump(mode="json"),
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=entity_id,
missed_time_range=(
datetime.now(timezone.utc) - timedelta(days=1),
datetime.now(timezone.utc),
),
),
failure_message=entity_failure_msg,
).model_dump(mode="json"),
],
}
],
)
assert response.status_code == 200
# Create CC Pair and run initial indexing attempt
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
user_performing_action=admin_user,
)
# Wait for first index attempt to complete
initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=initial_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# validate status
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=initial_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_index_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
# Verify initial state: doc1 indexed, doc2 failed
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == doc1.id
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 2
error_doc2 = next(error for error in errors if error.document_id == doc2.id)
assert error_doc2.failure_message == doc2_failure.failure_message
assert not error_doc2.is_resolved
error_entity = next(error for error in errors if error.entity_id == entity_id)
assert error_entity.failure_message == entity_failure_msg
assert not error_entity.is_resolved
# Update mock server to return success for both documents
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [
doc1.model_dump(mode="json"),
doc2.model_dump(mode="json"),
],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"failures": [],
}
],
)
assert response.status_code == 200
# Trigger another indexing attempt
# NOTE: must be from beginning to handle the entity failure
CCPairManager.run_once(
cc_pair, from_beginning=True, user_performing_action=admin_user
)
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
index_attempts_to_ignore=[initial_index_attempt.id],
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=recovery_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
finished_second_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=recovery_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_second_index_attempt.status == IndexingStatus.SUCCESS
# Verify both documents are now indexed
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 2
document_ids = {doc.id for doc in documents}
assert doc2.id in document_ids
assert doc1.id in document_ids
# Verify original failures were marked as resolved
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 2
error_doc2 = next(error for error in errors if error.document_id == doc2.id)
error_entity = next(error for error in errors if error.entity_id == entity_id)
assert error_doc2.is_resolved
assert error_entity.is_resolved
def test_mock_connector_checkpoint_recovery(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that checkpointing works correctly when an unhandled exception occurs
and that subsequent runs pick up from the last successful checkpoint."""
# Create test documents
# Create 100 docs for first batch, this is needed to get past the
# `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`.
docs_batch_1 = [create_test_document() for _ in range(100)]
doc2 = create_test_document()
doc3 = create_test_document()
# Set up mock server behavior for initial run:
# - First yield: 100 docs with checkpoint1
# - Second yield: doc2 with checkpoint2
# - Third yield: unhandled exception
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=True
).model_dump(mode="json"),
"failures": [],
},
{
"documents": [doc2.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=True
).model_dump(mode="json"),
"failures": [],
},
{
"documents": [],
# should never hit this, unhandled exception happens first
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"failures": [],
"unhandled_exception": "Simulated unhandled error",
},
],
)
assert response.status_code == 200
# Create CC Pair and run initial indexing attempt
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-checkpoint-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
user_performing_action=admin_user,
)
# Wait for first index attempt to complete
initial_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=initial_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# validate status
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=initial_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_index_attempt.status == IndexingStatus.FAILED
# Verify initial state: the 100-doc first batch and doc2 should be indexed
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 101 # 100 docs from first batch + doc2
document_ids = {doc.id for doc in documents}
assert doc2.id in document_ids
assert all(doc.id in document_ids for doc in docs_batch_1)
# Get the checkpoints that were sent to the mock server
response = mock_server_client.get("/get-checkpoints")
assert response.status_code == 200
initial_checkpoints = response.json()
# Verify we got the expected checkpoints in order
assert len(initial_checkpoints) > 0
assert (
initial_checkpoints[0]["checkpoint_content"] == {}
) # Initial empty checkpoint
assert initial_checkpoints[1]["checkpoint_content"] == {}
assert initial_checkpoints[2]["checkpoint_content"] == {}
# Reset the mock server for the next run
response = mock_server_client.post("/reset")
assert response.status_code == 200
# Set up mock server behavior for recovery run - should succeed fully this time
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [doc3.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"failures": [],
}
],
)
assert response.status_code == 200
# Trigger another indexing attempt
CCPairManager.run_once(
cc_pair, from_beginning=False, user_performing_action=admin_user
)
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
index_attempts_to_ignore=[initial_index_attempt.id],
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=recovery_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# validate status
finished_recovery_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=recovery_index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_recovery_attempt.status == IndexingStatus.SUCCESS
# Verify results
with get_session_context_manager() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 102 # 100 docs from first batch + doc2 + doc3
document_ids = {doc.id for doc in documents}
assert doc3.id in document_ids
assert doc2.id in document_ids
assert all(doc.id in document_ids for doc in docs_batch_1)
# Get the checkpoints from the recovery run
response = mock_server_client.get("/get-checkpoints")
assert response.status_code == 200
recovery_checkpoints = response.json()
# Verify the recovery run started from the last successful checkpoint
assert len(recovery_checkpoints) == 1
assert recovery_checkpoints[0]["checkpoint_content"] == {}
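For readers following these tests, the mock connector server's HTTP contract is implied entirely by the calls above: POST /set-behavior takes a list of connector "yields" (documents, checkpoint, failures, and an optional unhandled_exception), GET /get-checkpoints returns the checkpoints the connector was invoked with, and POST /reset clears state. A small illustrative helper wrapping that contract (not part of the codebase; endpoint paths and payload shapes are copied from the tests):

import httpx


def set_mock_connector_behavior(client: httpx.Client, yields: list[dict]) -> None:
    # Each entry describes one yield from the mock connector: the documents to
    # return, the ConnectorCheckpoint to emit, any ConnectorFailures, and
    # optionally an "unhandled_exception" message raised instead of yielding.
    response = client.post("/set-behavior", json=yields)
    response.raise_for_status()


def get_observed_checkpoints(client: httpx.Client) -> list[dict]:
    # Checkpoints the indexing run passed to the connector, in call order.
    response = client.get("/get-checkpoints")
    response.raise_for_status()
    return response.json()


def reset_mock_server(client: httpx.Client) -> None:
    response = client.post("/reset")
    response.raise_for_status()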

View File

@@ -61,6 +61,7 @@ services:
# Other services
- POSTGRES_HOST=relational_db
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose
@@ -97,6 +98,9 @@ services:
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
# Demo purposes
- MOCK_CONNECTOR_FILE_PATH=${MOCK_CONNECTOR_FILE_PATH:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
@@ -171,6 +175,7 @@ services:
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
- POSTGRES_DB=${POSTGRES_DB:-}
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
- POSTGRES_USE_NULL_POOL=${POSTGRES_USE_NULL_POOL:-}
- VESPA_HOST=index
- REDIS_HOST=cache
- WEB_DOMAIN=${WEB_DOMAIN:-} # For frontend redirect auth purpose for OAuth2 connectors
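The compose changes above only thread POSTGRES_USE_NULL_POOL through the environment; the engine-side handling lives elsewhere in the PR. As a rough sketch of what honoring the flag can look like with SQLAlchemy (build_engine and the exact truthiness check are illustrative assumptions, not the project's actual code):

import os

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool


def build_engine(db_url: str):
    # NullPool opens a fresh connection per checkout and closes it on release,
    # which keeps idle connections from piling up when many short-lived
    # integration-test workers hit the same Postgres instance in parallel.
    if os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true":
        return create_engine(db_url, poolclass=NullPool)
    return create_engine(db_url)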

View File

@@ -23,7 +23,7 @@ export default defineConfig({
viewport: { width: 1280, height: 720 },
storageState: "admin_auth.json",
},
testIgnore: ["**/codeUtils.test.ts"],
testIgnore: ["**/codeUtils.test.ts", "**/chat/**/*.spec.ts"],
},
],
});

View File

@@ -0,0 +1,141 @@
import { Modal } from "@/components/Modal";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { IndexAttemptError } from "./types";
import { localizeAndPrettify } from "@/lib/time";
import { Button } from "@/components/ui/button";
import { useState } from "react";
import { PageSelector } from "@/components/PageSelector";
interface IndexAttemptErrorsModalProps {
errors: {
items: IndexAttemptError[];
total_items: number;
};
onClose: () => void;
onResolveAll: () => void;
isResolvingErrors?: boolean;
onPageChange: (page: number) => void;
currentPage: number;
pageSize?: number;
}
const DEFAULT_PAGE_SIZE = 10;
export default function IndexAttemptErrorsModal({
errors,
onClose,
onResolveAll,
isResolvingErrors = false,
onPageChange,
currentPage,
pageSize = DEFAULT_PAGE_SIZE,
}: IndexAttemptErrorsModalProps) {
const totalPages = Math.ceil(errors.total_items / pageSize);
const hasUnresolvedErrors = errors.items.some((error) => !error.is_resolved);
return (
<Modal title="Indexing Errors" onOutsideClick={onClose} width="max-w-6xl">
<div className="flex flex-col gap-4">
<div className="flex flex-col gap-2">
{isResolvingErrors ? (
<div className="text-sm text-text-default">
Currently attempting to resolve all errors by performing a full
re-index. This may take some time to complete.
</div>
) : (
<>
<div className="text-sm text-text-default">
Below are the errors encountered during indexing. Each row
represents a failed document or entity.
</div>
<div className="text-sm text-text-default">
Click the button below to kick off a full re-index to try to
resolve these errors. A full re-index may take much longer than
a normal update.
</div>
</>
)}
</div>
<Table>
<TableHeader>
<TableRow>
<TableHead>Time</TableHead>
<TableHead>Document ID</TableHead>
<TableHead className="w-1/2">Error Message</TableHead>
<TableHead>Status</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{errors.items.map((error) => (
<TableRow key={error.id}>
<TableCell>{localizeAndPrettify(error.time_created)}</TableCell>
<TableCell>
{error.document_link ? (
<a
href={error.document_link}
target="_blank"
rel="noopener noreferrer"
className="text-link hover:underline"
>
{error.document_id || error.entity_id || "Unknown"}
</a>
) : (
error.document_id || error.entity_id || "Unknown"
)}
</TableCell>
<TableCell className="whitespace-normal">
{error.failure_message}
</TableCell>
<TableCell>
<span
className={`px-2 py-1 rounded text-xs ${
error.is_resolved
? "bg-green-100 text-green-800"
: "bg-red-100 text-red-800"
}`}
>
{error.is_resolved ? "Resolved" : "Unresolved"}
</span>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
<div className="mt-4">
{totalPages > 1 && (
<div className="flex-1 flex justify-center mb-2">
<PageSelector
totalPages={totalPages}
currentPage={currentPage + 1}
onPageChange={(page) => onPageChange(page - 1)}
/>
</div>
)}
<div className="flex w-full">
<div className="flex gap-2 ml-auto">
{hasUnresolvedErrors && !isResolvingErrors && (
<Button
onClick={onResolveAll}
variant="default"
className="ml-4 whitespace-nowrap"
>
Resolve All
</Button>
)}
</div>
</div>
</div>
</div>
</Modal>
);
}

View File

@@ -34,38 +34,26 @@ import usePaginatedFetch from "@/hooks/usePaginatedFetch";
const ITEMS_PER_PAGE = 8;
const PAGES_PER_BATCH = 8;
export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
export interface IndexingAttemptsTableProps {
ccPair: CCPairFullInfo;
indexAttempts: IndexAttemptSnapshot[];
currentPage: number;
totalPages: number;
onPageChange: (page: number) => void;
}
export function IndexingAttemptsTable({
ccPair,
indexAttempts,
currentPage,
totalPages,
onPageChange,
}: IndexingAttemptsTableProps) {
const [indexAttemptTracePopupId, setIndexAttemptTracePopupId] = useState<
number | null
>(null);
const {
currentPageData: pageOfIndexAttempts,
isLoading,
error,
currentPage,
totalPages,
goToPage,
} = usePaginatedFetch<IndexAttemptSnapshot>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: `${buildCCPairInfoUrl(ccPair.id)}/index-attempts`,
});
if (isLoading || !pageOfIndexAttempts) {
return <ThreeDotsLoader />;
}
if (error) {
return (
<ErrorCallout
errorTitle={`Failed to fetch info on Connector with ID ${ccPair.id}`}
errorMsg={error?.toString() || "Unknown error"}
/>
);
}
if (!pageOfIndexAttempts?.length) {
if (!indexAttempts?.length) {
return (
<Callout
className="mt-4"
@@ -78,7 +66,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
);
}
const indexAttemptToDisplayTraceFor = pageOfIndexAttempts?.find(
const indexAttemptToDisplayTraceFor = indexAttempts?.find(
(indexAttempt) => indexAttempt.id === indexAttemptTracePopupId
);
@@ -119,7 +107,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
</TableRow>
</TableHeader>
<TableBody>
{pageOfIndexAttempts.map((indexAttempt) => {
{indexAttempts.map((indexAttempt) => {
const docsPerMinute =
getDocsProcessedPerMinute(indexAttempt)?.toFixed(2);
return (
@@ -161,18 +149,6 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
<TableCell>{indexAttempt.total_docs_indexed}</TableCell>
<TableCell>
<div>
{indexAttempt.error_count > 0 && (
<Link
className="cursor-pointer my-auto"
href={`/admin/indexing/${indexAttempt.id}`}
>
<Text className="flex flex-wrap text-link whitespace-normal">
<SearchIcon />
&nbsp;View Errors
</Text>
</Link>
)}
{indexAttempt.status === "success" && (
<Text className="flex flex-wrap whitespace-normal">
{"-"}
@@ -209,7 +185,7 @@ export function IndexingAttemptsTable({ ccPair }: { ccPair: CCPairFullInfo }) {
<PageSelector
totalPages={totalPages}
currentPage={currentPage}
onPageChange={goToPage}
onPageChange={onPageChange}
/>
</div>
</div>

View File

@@ -1,11 +1,9 @@
"use client";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { runConnector } from "@/lib/connector";
import { Button } from "@/components/ui/button";
import Text from "@/components/ui/text";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { triggerIndexing } from "./lib";
import { useState } from "react";
import { Modal } from "@/components/Modal";
import { Separator } from "@/components/ui/separator";
@@ -23,26 +21,6 @@ function ReIndexPopup({
setPopup: (popupSpec: PopupSpec | null) => void;
hide: () => void;
}) {
async function triggerIndexing(fromBeginning: boolean) {
const errorMsg = await runConnector(
connectorId,
[credentialId],
fromBeginning
);
if (errorMsg) {
setPopup({
message: errorMsg,
type: "error",
});
} else {
setPopup({
message: "Triggered connector run",
type: "success",
});
}
mutate(buildCCPairInfoUrl(ccPairId));
}
return (
<Modal title="Run Indexing" onOutsideClick={hide}>
<div>
@@ -50,7 +28,13 @@ function ReIndexPopup({
variant="submit"
className="ml-auto"
onClick={() => {
triggerIndexing(false);
triggerIndexing(
false,
connectorId,
credentialId,
ccPairId,
setPopup
);
hide();
}}
>
@@ -68,7 +52,13 @@ function ReIndexPopup({
variant="submit"
className="ml-auto"
onClick={() => {
triggerIndexing(true);
triggerIndexing(
true,
connectorId,
credentialId,
ccPairId,
setPopup
);
hide();
}}
>

View File

@@ -1,4 +1,7 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { runConnector } from "@/lib/connector";
import { ValidSources } from "@/lib/types";
import { mutate } from "swr";
export function buildCCPairInfoUrl(ccPairId: string | number) {
return `/api/manage/admin/cc-pair/${ccPairId}`;
@@ -11,3 +14,29 @@ export function buildSimilarCredentialInfoURL(
const base = `/api/manage/admin/similar-credentials/${source_type}`;
return get_editable ? `${base}?get_editable=True` : base;
}
export async function triggerIndexing(
fromBeginning: boolean,
connectorId: number,
credentialId: number,
ccPairId: number,
setPopup: (popupSpec: PopupSpec | null) => void
) {
const errorMsg = await runConnector(
connectorId,
[credentialId],
fromBeginning
);
if (errorMsg) {
setPopup({
message: errorMsg,
type: "error",
});
} else {
setPopup({
message: "Triggered connector run",
type: "success",
});
}
mutate(buildCCPairInfoUrl(ccPairId));
}

View File

@@ -25,13 +25,24 @@ import DeletionErrorStatus from "./DeletionErrorStatus";
import { IndexingAttemptsTable } from "./IndexingAttemptsTable";
import { ModifyStatusButtonCluster } from "./ModifyStatusButtonCluster";
import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { buildCCPairInfoUrl, triggerIndexing } from "./lib";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import {
CCPairFullInfo,
ConnectorCredentialPairStatus,
IndexAttemptError,
PaginatedIndexAttemptErrors,
} from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
import EditPropertyModal from "@/components/modals/EditPropertyModal";
import * as Yup from "yup";
import { AlertCircle } from "lucide-react";
import IndexAttemptErrorsModal from "./IndexAttemptErrorsModal";
import usePaginatedFetch from "@/hooks/usePaginatedFetch";
import { IndexAttemptSnapshot } from "@/lib/types";
import { Spinner } from "@/components/Spinner";
// synchronize these validations with the SQLAlchemy connector class until we have a
// centralized schema for both frontend and backend
@@ -51,43 +62,99 @@ const PruneFrequencySchema = Yup.object().shape({
.required("Property value is required"),
});
const ITEMS_PER_PAGE = 8;
const PAGES_PER_BATCH = 8;
function Main({ ccPairId }: { ccPairId: number }) {
const router = useRouter(); // Initialize the router
const router = useRouter();
const {
data: ccPair,
isLoading,
error,
isLoading: isLoadingCCPair,
error: ccPairError,
} = useSWR<CCPairFullInfo>(
buildCCPairInfoUrl(ccPairId),
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const {
currentPageData: indexAttempts,
isLoading: isLoadingIndexAttempts,
currentPage,
totalPages,
goToPage,
} = usePaginatedFetch<IndexAttemptSnapshot>({
itemsPerPage: ITEMS_PER_PAGE,
pagesPerBatch: PAGES_PER_BATCH,
endpoint: `${buildCCPairInfoUrl(ccPairId)}/index-attempts`,
});
const {
currentPageData: indexAttemptErrorsPage,
currentPage: errorsCurrentPage,
totalPages: errorsTotalPages,
goToPage: goToErrorsPage,
} = usePaginatedFetch<IndexAttemptError>({
itemsPerPage: 10,
pagesPerBatch: 1,
endpoint: `/api/manage/admin/cc-pair/${ccPairId}/errors`,
});
const indexAttemptErrors = indexAttemptErrorsPage
? {
items: indexAttemptErrorsPage,
total_items:
errorsCurrentPage === errorsTotalPages &&
indexAttemptErrorsPage.length === 0
? 0
: errorsTotalPages * 10,
}
: null;
const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
const [showIndexAttemptErrors, setShowIndexAttemptErrors] = useState(false);
const [showIsResolvingKickoffLoader, setShowIsResolvingKickoffLoader] =
useState(false);
const { popup, setPopup } = usePopup();
const latestIndexAttempt = indexAttempts?.[0];
const isResolvingErrors =
(latestIndexAttempt?.status === "in_progress" ||
latestIndexAttempt?.status === "not_started") &&
latestIndexAttempt?.from_beginning &&
// if there are errors in the latest index attempt, we don't want to show the loader
!indexAttemptErrors?.items?.some(
(error) => error.index_attempt_id === latestIndexAttempt?.id
);
const finishConnectorDeletion = useCallback(() => {
router.push("/admin/indexing/status?message=connector-deleted");
}, [router]);
useEffect(() => {
if (isLoading) {
if (isLoadingCCPair) {
return;
}
if (ccPair && !error) {
if (ccPair && !ccPairError) {
setHasLoadedOnce(true);
}
if (
(hasLoadedOnce && (error || !ccPair)) ||
(hasLoadedOnce && (ccPairError || !ccPair)) ||
(ccPair?.status === ConnectorCredentialPairStatus.DELETING &&
!ccPair.connector)
) {
finishConnectorDeletion();
}
}, [isLoading, ccPair, error, hasLoadedOnce, finishConnectorDeletion]);
}, [
isLoadingCCPair,
ccPair,
ccPairError,
hasLoadedOnce,
finishConnectorDeletion,
]);
const handleUpdateName = async (newName: string) => {
try {
@@ -191,15 +258,19 @@ function Main({ ccPairId }: { ccPairId: number }) {
}
};
if (isLoading) {
if (isLoadingCCPair || isLoadingIndexAttempts) {
return <ThreeDotsLoader />;
}
if (!ccPair || (!hasLoadedOnce && error)) {
if (!ccPair || (!hasLoadedOnce && ccPairError)) {
return (
<ErrorCallout
errorTitle={`Failed to fetch info on Connector with ID ${ccPairId}`}
errorMsg={error?.info?.detail || error?.toString() || "Unknown error"}
errorMsg={
ccPairError?.info?.detail ||
ccPairError?.toString() ||
"Unknown error"
}
/>
);
}
@@ -219,6 +290,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
return (
<>
{popup}
{showIsResolvingKickoffLoader && !isResolvingErrors && <Spinner />}
{editingRefreshFrequency && (
<EditPropertyModal
@@ -244,6 +316,32 @@ function Main({ ccPairId }: { ccPairId: number }) {
/>
)}
{showIndexAttemptErrors && indexAttemptErrors && (
<IndexAttemptErrorsModal
errors={indexAttemptErrors}
onClose={() => setShowIndexAttemptErrors(false)}
onResolveAll={async () => {
setShowIndexAttemptErrors(false);
setShowIsResolvingKickoffLoader(true);
await triggerIndexing(
true,
ccPair.connector.id,
ccPair.credential.id,
ccPair.id,
setPopup
);
// show the loader for a max of 10 seconds
setTimeout(() => {
setShowIsResolvingKickoffLoader(false);
}, 10000);
}}
isResolvingErrors={isResolvingErrors}
onPageChange={goToErrorsPage}
currentPage={errorsCurrentPage}
/>
)}
<BackButton
behaviorOverride={() => router.push("/admin/indexing/status")}
/>
@@ -342,13 +440,46 @@ function Main({ ccPairId }: { ccPairId: number }) {
/>
)}
{/* NOTE: no divider / title here for `ConfigDisplay` since it is optional and we need
to render these conditionally.*/}
<div className="mt-6">
<div className="flex">
<Title>Indexing Attempts</Title>
</div>
<IndexingAttemptsTable ccPair={ccPair} />
{indexAttemptErrors && indexAttemptErrors.total_items > 0 && (
<Alert className="border-alert bg-yellow-50 my-2">
<AlertCircle className="h-4 w-4 text-yellow-700" />
<AlertTitle className="text-yellow-950 font-semibold">
Some documents failed to index
</AlertTitle>
<AlertDescription className="text-yellow-900">
{isResolvingErrors ? (
<span>
<span className="text-sm text-yellow-700 animate-pulse">
Resolving failures
</span>
</span>
) : (
<>
We ran into some issues while processing some documents.{" "}
<b
className="text-link cursor-pointer"
onClick={() => setShowIndexAttemptErrors(true)}
>
View details.
</b>
</>
)}
</AlertDescription>
</Alert>
)}
{indexAttempts && (
<IndexingAttemptsTable
ccPair={ccPair}
indexAttempts={indexAttempts}
currentPage={currentPage}
totalPages={totalPages}
onPageChange={goToPage}
/>
)}
</div>
<Separator />
<div className="flex mt-4">

View File

@@ -37,3 +37,27 @@ export interface PaginatedIndexAttempts {
page: number;
total_pages: number;
}
export interface IndexAttemptError {
id: number;
connector_credential_pair_id: number;
document_id: string | null;
document_link: string | null;
entity_id: string | null;
failed_time_range_start: string | null;
failed_time_range_end: string | null;
failure_message: string;
is_resolved: boolean;
time_created: string;
index_attempt_id: number;
}
export interface PaginatedIndexAttemptErrors {
items: IndexAttemptError[];
total_items: number;
}
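The IndexAttemptError and PaginatedIndexAttemptErrors interfaces above mirror whatever the paginated /api/manage/admin/cc-pair/{id}/errors endpoint returns. A sketch of that response shape as Pydantic models, purely to make the contract explicit (class names are assumptions; field names are taken from the TypeScript interfaces):

from datetime import datetime

from pydantic import BaseModel


class IndexAttemptErrorResponse(BaseModel):
    id: int
    connector_credential_pair_id: int
    document_id: str | None
    document_link: str | None
    entity_id: str | None
    failed_time_range_start: datetime | None
    failed_time_range_end: datetime | None
    failure_message: str
    is_resolved: bool
    time_created: datetime
    index_attempt_id: int


class PaginatedIndexAttemptErrorsResponse(BaseModel):
    items: list[IndexAttemptErrorResponse]
    total_items: int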

View File

@@ -1,189 +0,0 @@
"use client";
import { Modal } from "@/components/Modal";
import { PageSelector } from "@/components/PageSelector";
import { CheckmarkIcon, CopyIcon } from "@/components/icons/icons";
import { localizeAndPrettify } from "@/lib/time";
import {
Table,
TableBody,
TableCell,
TableHead,
TableRow,
} from "@/components/ui/table";
import Text from "@/components/ui/text";
import { useState } from "react";
import { IndexAttemptError } from "./types";
import { TableHeader } from "@/components/ui/table";
const NUM_IN_PAGE = 8;
export function CustomModal({
isVisible,
onClose,
title,
content,
showCopyButton = false,
}: {
isVisible: boolean;
onClose: () => void;
title: string;
content: string;
showCopyButton?: boolean;
}) {
const [copyClicked, setCopyClicked] = useState(false);
if (!isVisible) return null;
return (
<Modal
width="w-4/6"
className="h-5/6 overflow-y-hidden flex flex-col"
title={title}
onOutsideClick={onClose}
>
<div className="overflow-y-auto mb-6">
{showCopyButton && (
<div className="mb-6">
{!copyClicked ? (
<div
onClick={() => {
navigator.clipboard.writeText(content);
setCopyClicked(true);
setTimeout(() => setCopyClicked(false), 2000);
}}
className="flex w-fit cursor-pointer hover:bg-accent-background p-2 border-border border rounded"
>
Copy full content
<CopyIcon className="ml-2 my-auto" />
</div>
) : (
<div className="flex w-fit hover:bg-accent-background p-2 border-border border rounded cursor-default">
Copied to clipboard
<CheckmarkIcon
className="my-auto ml-2 flex flex-shrink-0 text-success"
size={16}
/>
</div>
)}
</div>
)}
<div className="whitespace-pre-wrap">{content}</div>
</div>
</Modal>
);
}
export function IndexAttemptErrorsTable({
indexAttemptErrors,
}: {
indexAttemptErrors: IndexAttemptError[];
}) {
const [page, setPage] = useState(1);
const [modalData, setModalData] = useState<{
id: number | null;
title: string;
content: string;
} | null>(null);
const closeModal = () => setModalData(null);
return (
<>
{modalData && (
<CustomModal
isVisible={!!modalData}
onClose={closeModal}
title={modalData.title}
content={modalData.content}
showCopyButton
/>
)}
<Table>
<TableHeader>
<TableRow>
<TableHead>Timestamp</TableHead>
<TableHead>Batch Number</TableHead>
<TableHead>Document Summaries</TableHead>
<TableHead>Error Message</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{indexAttemptErrors
.slice(NUM_IN_PAGE * (page - 1), NUM_IN_PAGE * page)
.map((indexAttemptError) => {
return (
<TableRow key={indexAttemptError.id}>
<TableCell>
{indexAttemptError.time_created
? localizeAndPrettify(indexAttemptError.time_created)
: "-"}
</TableCell>
<TableCell>{indexAttemptError.batch_number}</TableCell>
<TableCell>
{indexAttemptError.doc_summaries && (
<div
onClick={() =>
setModalData({
id: indexAttemptError.id,
title: "Document Summaries",
content: JSON.stringify(
indexAttemptError.doc_summaries,
null,
2
),
})
}
className="mt-2 text-link cursor-pointer select-none"
>
View Document Summaries
</div>
)}
</TableCell>
<TableCell>
<div>
<Text className="flex flex-wrap whitespace-normal">
{indexAttemptError.error_msg || "-"}
</Text>
{indexAttemptError.traceback && (
<div
onClick={() =>
setModalData({
id: indexAttemptError.id,
title: "Exception Traceback",
content: indexAttemptError.traceback!,
})
}
className="mt-2 text-link cursor-pointer select-none"
>
View Full Trace
</div>
)}
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
{indexAttemptErrors.length > NUM_IN_PAGE && (
<div className="mt-3 flex">
<div className="mx-auto">
<PageSelector
totalPages={Math.ceil(indexAttemptErrors.length / NUM_IN_PAGE)}
currentPage={page}
onPageChange={(newPage) => {
setPage(newPage);
window.scrollTo({
top: 0,
left: 0,
behavior: "smooth",
});
}}
/>
</div>
</div>
)}
</>
);
}

View File

@@ -1,3 +0,0 @@
export function buildIndexingErrorsUrl(id: string | number) {
return `/api/manage/admin/indexing-errors/${id}`;
}

View File

@@ -1,59 +0,0 @@
"use client";
import { use } from "react";
import { BackButton } from "@/components/BackButton";
import { ErrorCallout } from "@/components/ErrorCallout";
import { ThreeDotsLoader } from "@/components/Loading";
import { errorHandlingFetcher } from "@/lib/fetcher";
import Title from "@/components/ui/title";
import useSWR from "swr";
import { IndexAttemptErrorsTable } from "./IndexAttemptErrorsTable";
import { buildIndexingErrorsUrl } from "./lib";
import { IndexAttemptError } from "./types";
function Main({ id }: { id: number }) {
const {
data: indexAttemptErrors,
isLoading,
error,
} = useSWR<IndexAttemptError[]>(
buildIndexingErrorsUrl(id),
errorHandlingFetcher
);
if (isLoading) {
return <ThreeDotsLoader />;
}
if (error || !indexAttemptErrors) {
return (
<ErrorCallout
errorTitle={`Failed to fetch errors for attempt ID ${id}`}
errorMsg={error?.info?.detail || error.toString()}
/>
);
}
return (
<>
<BackButton />
<div className="mt-6">
<div className="flex">
<Title>Indexing Errors for Attempt {id}</Title>
</div>
<IndexAttemptErrorsTable indexAttemptErrors={indexAttemptErrors} />
</div>
</>
);
}
export default function Page(props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
const id = parseInt(params.id);
return (
<div className="mx-auto container">
<Main id={id} />
</div>
);
}

View File

@@ -1,15 +0,0 @@
export interface IndexAttemptError {
id: number;
index_attempt_id: number;
batch_number: number;
doc_summaries: DocumentErrorSummary[];
error_msg: string;
traceback: string;
time_created: string;
}
export interface DocumentErrorSummary {
id: string;
semantic_id: string;
section_link: string;
}

View File

@@ -41,25 +41,11 @@ export function IndexAttemptStatus({
badge = icon;
}
} else if (status === "completed_with_errors") {
const icon = (
badge = (
<Badge variant="secondary" icon={FiAlertTriangle}>
Completed with errors
</Badge>
);
badge = (
<HoverPopup
mainContent={<div className="cursor-pointer">{icon}</div>}
popupContent={
<div className="w-64 p-2 break-words overflow-hidden whitespace-normal">
The indexing attempt completed, but some errors were encountered
during the run.
<br />
<br />
Click View Errors for more details.
</div>
}
/>
);
} else if (status === "success") {
badge = (
<Badge variant="success" icon={FiCheckCircle}>

View File

@@ -7,12 +7,13 @@ import {
} from "@/lib/types";
import { ChatSessionMinimal } from "@/app/ee/admin/performance/usage/types";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { PaginatedIndexAttemptErrors } from "@/app/admin/connector/[ccPairId]/types";
type PaginatedType =
| IndexAttemptSnapshot
| AcceptedUserSnapshot
| InvitedUserSnapshot
| ChatSessionMinimal;
// Any type that has an id property
type PaginatedType = {
id: number | string;
[key: string]: any;
};
interface PaginatedApiResponse<T extends PaginatedType> {
items: T[];

View File

@@ -1232,6 +1232,7 @@ export interface ConnectorBase<T> {
indexing_start: Date | null;
access_type: string;
groups?: number[];
from_beginning?: boolean;
}
export interface Connector<T> extends ConnectorBase<T> {
@@ -1253,6 +1254,7 @@ export interface ConnectorSnapshot {
indexing_start: number | null;
time_created: string;
time_updated: string;
from_beginning?: boolean;
}
export interface WebConfig {

View File

@@ -335,6 +335,13 @@ export const SOURCE_METADATA_MAP: SourceMap = {
displayName: "Not Applicable",
category: SourceCategory.Other,
},
// Just so integration tests don't crash the UI
mock_connector: {
icon: GlobeIcon,
displayName: "Mock Connector",
category: SourceCategory.Other,
},
} as SourceMap;
function fillSourceMetadata(

View File

@@ -123,6 +123,7 @@ export interface FailedConnectorIndexingStatus {
export interface IndexAttemptSnapshot {
id: number;
status: ValidStatuses | null;
from_beginning: boolean;
new_docs_indexed: number;
docs_removed_from_index: number;
total_docs_indexed: number;