
feat: connector indexing decoupling (#4893)

* WIP

* renamed and moved tasks (WIP)

* minio migration

* bug fixes and finally add document batch storage

* WIP: can succeed but status is error

* WIP

* import fixes

* working v1 of decoupled

* catastrophe handling

* refactor

* remove unused db session in prep for new approach

* renaming and docstrings (untested)

* renames

* WIP with no more indexing fences

* robustness improvements

* clean up rebase

* migration and salesforce rate limits

* minor tweaks

* test fix

* connector pausing behavior

* correct checkpoint resumption logic

* cleanups in docfetching

* add heartbeat file

* update template jsonc

* deployment fixes

* fix vespa httpx pool

* error handling

* cosmetic fixes

* dumb

* logging improvements and non-checkpointed connector fixes

* didn't save

* misc fixes

* fix import

* fix deletion of old files

* add in attempt prefix

* fix attempt prefix

* tiny log improvement

* minor changes

* fixed resumption behavior

* passing int tests

* fix unit test

* fixed unit tests

* trying timeout bump to see if int tests pass

* trying timeout bump to see if int tests pass

* fix autodiscovery

* helm chart fixes

* helm and logging
This commit is contained in:
Evan Lohn
2025-07-21 20:33:25 -07:00
committed by GitHub
parent 1f3cc9ed6e
commit bd06147d26
107 changed files with 4976 additions and 2601 deletions

View File

@@ -46,7 +46,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery docfetching",
"Celery docprocessing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
@@ -226,35 +227,66 @@
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery indexing",
"name": "Celery docfetching",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
],
"presentation": {
"group": "2"
"group": "2"
},
"consoleTitle": "Celery indexing Console"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Celery monitoring",
"type": "debugpy",
@@ -303,35 +335,6 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",

View File

@@ -96,7 +96,7 @@ def get_google_drive_documents_from_database() -> list[dict]:
result = bind.execute(
sa.text(
"""
SELECT d.id, cc.id as cc_pair_id
SELECT d.id
FROM document d
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
@@ -109,7 +109,7 @@ def get_google_drive_documents_from_database() -> list[dict]:
documents = []
for row in result:
documents.append({"document_id": row.id, "cc_pair_id": row.cc_pair_id})
documents.append({"document_id": row.id})
return documents

View File

@@ -0,0 +1,115 @@
"""add_indexing_coordination
Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add database-based coordination fields (replacing Redis fencing)
op.add_column(
"index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"cancellation_requested",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
# Add batch coordination fields (replacing FileStore state)
op.add_column(
"index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"completed_batches", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"total_failures_batch_level",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"index_attempt",
sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
)
# Progress tracking for stall detection
op.add_column(
"index_attempt",
sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column(
"last_batches_completed_count",
sa.Integer(),
nullable=False,
server_default="0",
),
)
# Heartbeat tracking for worker liveness detection
op.add_column(
"index_attempt",
sa.Column(
"heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
)
# Add index for coordination queries
op.create_index(
"ix_index_attempt_active_coordination",
"index_attempt",
["connector_credential_pair_id", "search_settings_id", "status"],
)
def downgrade() -> None:
# Remove the new index
op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")
# Remove the new columns
op.drop_column("index_attempt", "last_batches_completed_count")
op.drop_column("index_attempt", "last_progress_time")
op.drop_column("index_attempt", "last_heartbeat_time")
op.drop_column("index_attempt", "last_heartbeat_value")
op.drop_column("index_attempt", "heartbeat_counter")
op.drop_column("index_attempt", "total_chunks")
op.drop_column("index_attempt", "total_failures_batch_level")
op.drop_column("index_attempt", "completed_batches")
op.drop_column("index_attempt", "total_batches")
op.drop_column("index_attempt", "cancellation_requested")
op.drop_column("index_attempt", "celery_task_id")

View File

@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(

View File

@@ -91,7 +91,7 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store(db_session).save_file(
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -422,7 +422,7 @@ def connector_permission_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)

View File

@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
db_session: Session,
report_display_name: str,
) -> IO:
"""
@@ -128,7 +127,7 @@ def get_usage_report_data(
Returns:
The usage report data.
"""
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True

View File

@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
upload_logo(file=file, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -6,7 +6,6 @@ from typing import IO
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
return "application/octet-stream"
def upload_logo(
db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
content: IO[Any]
if isinstance(file, str):
@@ -129,7 +126,7 @@ def upload_logo(
display_name = file.filename
file_type = file.content_type or "image/jpeg"
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=content,
display_name=display_name,

View File

@@ -358,7 +358,7 @@ def get_query_history_export_status(
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we should then check if the export file has already been stored inside of the file-store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
@@ -385,7 +385,7 @@ def download_query_history_csv(
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
has_file = file_store.has_file(
file_id=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -53,7 +53,7 @@ def read_usage_report(
db_session: Session = Depends(get_session),
) -> Response:
try:
file = get_usage_report_data(db_session, report_name)
file = get_usage_report_data(report_name)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -112,7 +112,7 @@ def create_new_usage_report(
period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
report_id = str(uuid.uuid4())
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
messages_file_id = generate_chat_messages_report(
db_session, file_store, report_id, period

View File

@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
store_ee_settings(final_enterprise_settings)
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
def _seed_logo(logo_path: str | None) -> None:
if logo_path:
logger.notice("Uploading logo")
upload_logo(db_session=db_session, file=logo_path)
upload_logo(file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_logo(seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)

View File

@@ -40,6 +40,7 @@ from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -93,7 +94,13 @@ def on_task_prerun(
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
pass
# Reset any per-task logging context so that prefixes (e.g. pruning_ctx)
# from a previous task executed in the same worker process do not leak
# into the next task's log messages. This fixes incorrect [CC Pair:/Index Attempt]
# prefixes observed when a pruning task finishes and an indexing task
# runs in the same process.
LoggerContextVars.reset()
def on_task_postrun(
@@ -474,7 +481,8 @@ class TenantContextFilter(logging.Filter):
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id:
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
# Match the 8 character tenant abbreviation used in OnyxLoggingAdapter
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8]
record.name = f"[t:{tenant_id}]"
else:
record.name = ""

View File

@@ -0,0 +1,102 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
]
)

View File

@@ -12,7 +12,7 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -21,7 +21,7 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@@ -60,7 +60,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
@@ -108,6 +108,6 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -116,6 +116,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
]
)

View File

@@ -9,6 +9,7 @@ from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.result import AsyncResult
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -18,9 +19,6 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
@@ -30,6 +28,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -168,24 +167,50 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
# This uses database coordination instead of Redis fencing
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
# Get potentially orphaned attempts (those with active status and task IDs)
potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids(
db_session
)
for attempt_id in potentially_orphaned_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
# handle case where not started or docfetching is done but indexing is not
if (
not attempt
or not attempt.celery_task_id
or attempt.total_batches is not None
):
continue
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
# Check if the Celery task actually exists
try:
result: AsyncResult = AsyncResult(attempt.celery_task_id)
# If the task is not in PENDING state, it exists in Celery
if result.state != "PENDING":
continue
# Task is orphaned - mark as failed
failure_reason = (
f"Orphaned index attempt found on startup - Celery task not found: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"celery_task_id={attempt.celery_task_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
except Exception:
# If we can't check the task status, be conservative and continue
logger.warning(
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
f"task_id={attempt.celery_task_id}"
)
@worker_ready.connect
@@ -292,7 +317,7 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -26,7 +26,7 @@ def celery_get_unacked_length(r: Redis) -> int:
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queue are "prefetched", so this gives
Unacked entries belonging to the indexing queues are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()

View File

@@ -0,0 +1,22 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Docfetching worker configuration
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
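
CELERY_WORKER_DOCFETCHING_CONCURRENCY is imported from onyx.configs.app_configs but its definition is not shown in this excerpt. A plausible sketch, assuming it follows the usual pattern of an environment-variable override with a small default:

# Assumed definition, not part of this diff.
import os

CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY") or 1
)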

View File

@@ -1,5 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -24,6 +24,6 @@ task_acks_late = shared_config.task_acks_late
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -40,9 +40,11 @@ from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
@@ -69,13 +71,21 @@ def revoke_tasks_blocking_deletion(
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=redis_connector.cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
and recent_index_attempts[0].celery_task_id
):
app.control.revoke(recent_index_attempts[0].celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
f"Revoked indexing task {recent_index_attempts[0].celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
@@ -281,8 +291,16 @@ def try_generate_document_cc_pair_cleanup_tasks(
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "

View File

@@ -0,0 +1,675 @@
import multiprocessing
import os
import time
import traceback
from http import HTTPStatus
from time import sleep
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
from onyx.background.celery.tasks.models import DocProcessingContext
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
def _verify_indexing_attempt(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
) -> None:
"""
Verify that the indexing attempt exists and is in the correct state.
"""
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
if attempt.connector_credential_pair_id != cc_pair_id:
raise SimpleJobException(
f"docfetching_task - CC pair mismatch: "
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.search_settings_id != search_settings_id:
raise SimpleJobException(
f"docfetching_task - Search settings mismatch: "
f"expected={search_settings_id} actual={attempt.search_settings_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.status not in [
IndexingStatus.NOT_STARTED,
IndexingStatus.IN_PROGRESS,
]:
raise SimpleJobException(
f"docfetching_task - Invalid attempt status: "
f"attempt_id={index_attempt_id} status={attempt.status}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
# Check for cancellation
if IndexingCoordination.check_cancellation_requested(
db_session, index_attempt_id
):
raise SimpleJobException(
f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
logger.info(
f"docfetching_task - IndexAttempt verified: "
f"attempt_id={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
def docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
"""
This function is run in a SimpleJob as a new process. It is responsible for validating
some stuff, but basically it just calls run_indexing_entrypoint.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Start heartbeat for this indexing attempt
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
try:
_docfetching_task(
app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id
)
finally:
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
def _docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
raise SimpleJobException(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
)
if redis_connector.stop.fenced:
raise SimpleJobException(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
# Verify the indexing attempt exists and is valid
# This replaces the Redis fence payload waiting
_verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id)
# We still need a basic Redis lock to prevent duplicate task execution
# but this is much simpler than the full fencing mechanism
r = get_redis_client()
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
f"Docfetching task already running, exiting...: "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise SimpleJobException(
f"Docfetching task already running, exiting...: "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
)
try:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector,
lock,
r,
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
app,
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# special bulletproofing ... truncate long exception messages
# for exception types that require more args, this will fail
# thus the try/except
try:
sanitized_e = type(e)(str(e)[:1024])
sanitized_e.__traceback__ = e.__traceback__
raise sanitized_e
except Exception:
raise e
finally:
if lock.owned():
lock.release()
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
os._exit(0) # ensure process exits cleanly
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
if job.process:
result.exit_code = job.process.exitcode
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
return result
@shared_task(
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def docfetching_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
) -> None:
"""
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
docfetching and docprocessing.
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
This task spawns a new process for a new scheduled index attempt. That
new process (which runs the docfetching_task function) does the following:
1) determines parameters of the indexing attempt (which connector indexing function to run,
start and end time, from prev checkpoint or not), then run that connector. Specifically,
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
At the moment these two steps (reading external data and converting to an Onyx document)
are not parallelized in most connectors; that's a subject for future work.
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned
to process it. docprocessing involves the steps listed below.
2) upserts documents to postgres (index_doc_batch_prepare)
3) chunks each document (optionally adds context for contextual rag)
4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
5) write chunks to vespa (write_chunks_to_vector_db_with_backoff)
6) update document and indexing metadata in postgres
7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes
all locally stored IDs missing from the most recently pulled document ID list
Some important notes:
Invariants:
- docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog.
The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed
if it is not making progress.
- All docprocessing tasks are spawned by a docfetching task.
- all docfetching tasks, docprocessing tasks, and document batches in the file store are
associated with a specific index attempt.
- the index attempt status is the source of truth for what is currently happening with the index attempt.
It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible.
How we deal with failures / partial indexing:
- non-checkpointed connectors / new runs in general => delete the old document batches from the file store and do the new run
- checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run
Misc:
- most inter-process communication is handled in postgres, some is still in redis and we're trying to remove it
- Heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveness
- progress-based liveness check: if nothing is done in 3-6 hours, mark the attempt as failed
- TODO: task level timeouts (i.e. a connector stuck in an infinite loop)
Comments below are from the old version and some may no longer be valid.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
Some more Richard notes:
celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
# TODO: remove dependence on Redis
start = time.monotonic()
result = SimpleJobResult()
ctx = DocProcessingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
task_logger.info(
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
job = client.submit(
docfetching_task,
self.app,
index_attempt_id,
cc_pair_id,
search_settings_id,
global_version.is_ee_version(),
tenant_id,
)
if not job or not job.process:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
eager_load_cc_pair=True,
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
while True:
sleep(5)
time.monotonic()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_current_tenant() as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)
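
The docstring of docfetching_proxy_task above describes the handoff between the two stages: each batch produced by docfetching is persisted to the file store and a docprocessing task is enqueued for it, while batch accounting lives on the IndexAttempt row. A minimal sketch of that handoff, with all names hypothetical (the real task name, file-store API, and payload shape are defined elsewhere in this PR):

import json

def hand_off_batch(
    celery_app,
    file_store,
    index_attempt_id: int,
    batch_num: int,
    documents: list[dict],
) -> None:
    # 1) persist the batch so the docprocessing worker can load it on its own
    batch_key = f"index_attempt_{index_attempt_id}/batch_{batch_num}.json"  # hypothetical key
    file_store.put(batch_key, json.dumps(documents).encode("utf-8"))  # hypothetical API

    # 2) enqueue processing of exactly this batch; total_batches / completed_batches
    #    on the IndexAttempt row track overall completion
    celery_app.send_task(
        "docprocessing",  # queue name taken from the worker config above
        kwargs={"index_attempt_id": index_attempt_id, "batch_num": batch_num},
    )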

View File

@@ -0,0 +1,36 @@
import threading
from sqlalchemy import update
from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt
def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
"""Start a heartbeat thread for the given index attempt"""
stop_event = threading.Event()
def heartbeat_loop() -> None:
while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL):
try:
with get_session_with_current_tenant() as db_session:
db_session.execute(
update(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
)
db_session.commit()
except Exception:
# Silently continue if heartbeat fails
pass
thread = threading.Thread(target=heartbeat_loop, daemon=True)
thread.start()
return thread, stop_event
def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None:
"""Stop the heartbeat thread"""
stop_event.set()
thread.join(timeout=5) # Wait up to 5 seconds for clean shutdown
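
The counter incremented here is one half of the liveness check; the last_heartbeat_value and last_heartbeat_time columns added in the migration let a monitor detect whether the counter advanced since its previous observation. A minimal sketch of that monitoring side (illustrative, not part of the diff):

from datetime import datetime, timezone

from sqlalchemy.orm import Session

from onyx.db.models import IndexAttempt

def heartbeat_advanced(db_session: Session, attempt: IndexAttempt) -> bool:
    """Record the current heartbeat and report whether it moved since the last check."""
    advanced = attempt.heartbeat_counter > attempt.last_heartbeat_value
    if advanced:
        attempt.last_heartbeat_value = attempt.heartbeat_counter
        attempt.last_heartbeat_time = datetime.now(timezone.utc)
        db_session.commit()
    return advanced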

File diff suppressed because it is too large

View File

@@ -1,10 +1,8 @@
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
@@ -12,8 +10,6 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -21,27 +17,19 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
@@ -50,54 +38,6 @@ logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
class IndexingCallbackBase(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
@@ -123,10 +63,9 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
def progress(self, tag: str, amount: int) -> None:
"""Amount isn't used yet."""
@@ -178,179 +117,16 @@ class IndexingCallback(IndexingCallbackBase):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
def progress(self, tag: str, amount: int) -> None:
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
def validate_indexing_fence(
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
f"index_attempt={payload.index_attempt_id}",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed: "
f"index_attempt={payload.index_attempt_id}",
)
redis_connector_index.reset()
return
def validate_indexing_fences(
tenant_id: str,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_current_tenant() as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
lock_beat.reacquire()
return
# NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed
# as they are no longer needed with database-based coordination. The new validation is
# handled by validate_active_indexing_attempts in the main indexing tasks module.
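# Illustrative sketch of the database-based liveness check that replaces the Redis
# fence validation above: an attempt is treated as stale when its heartbeat
# timestamp has not been refreshed recently. The field name and threshold are
# assumptions for the example, not the exact implementation.
from datetime import datetime, timedelta, timezone

HEARTBEAT_STALE_AFTER_EXAMPLE = timedelta(minutes=5)

def attempt_looks_stale_example(last_heartbeat: datetime | None) -> bool:
    if last_heartbeat is None:
        return True
    return datetime.now(timezone.utc) - last_heartbeat > HEARTBEAT_STALE_AFTER_EXAMPLE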
def is_in_repeated_error_state(
@@ -414,10 +190,12 @@ def should_index(
)
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
task_logger.info(
f"_should_index: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector_id} "
f"refresh_freq={connector.refresh_freq}"
)
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -517,7 +295,7 @@ def should_index(
return True
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
@@ -531,10 +309,11 @@ def try_creating_indexing_task(
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
@@ -547,61 +326,42 @@ def try_creating_indexing_task(
if not acquired:
return None
redis_connector_index: RedisConnectorIndex
index_attempt_id = None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# set a basic fence to start
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -613,14 +373,18 @@ def try_creating_indexing_task(
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
@@ -628,9 +392,10 @@ def try_creating_indexing_task(
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():

File diff suppressed because it is too large

View File

@@ -0,0 +1,110 @@
from enum import Enum
from pydantic import BaseModel
class DocProcessingContext(BaseModel):
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None
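# Example of the exit-code round trip defined above; the values come straight from
# the mapping in IndexingWatchdogTerminalStatus and are shown here only for
# illustration.
_example_status = IndexingWatchdogTerminalStatus.from_code(137)
assert _example_status is IndexingWatchdogTerminalStatus.OUT_OF_MEMORY
assert _example_status.code == 137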

View File

@@ -147,7 +147,7 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"indexing_queue_length": "indexing",
"docprocessing_queue_length": "docprocessing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.DOCPROCESSING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"docfetching={n_docfetching} "
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -22,7 +22,7 @@ from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -464,7 +464,7 @@ def connector_pruning_generator_task(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -10,7 +10,7 @@ set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.indexing import celery_app
from onyx.background.celery.apps.docfetching import celery_app
return celery_app

View File

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.docprocessing import celery_app
return celery_app
app = get_app()

View File

@@ -33,7 +33,7 @@ def save_checkpoint(
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store()
file_store.save_file(
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(
def load_checkpoint(
index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store()
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointedConnector):
@@ -71,7 +71,7 @@ def get_latest_valid_checkpoint(
window_start: datetime,
window_end: datetime,
connector: BaseConnector,
) -> tuple[ConnectorCheckpoint, bool]:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
@@ -83,7 +83,7 @@ def get_latest_valid_checkpoint(
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# where we make no progress. Only do this if we have had at least
# _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts.
if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
had_any_progress = False
for candidate in checkpoint_candidates:
if (
@@ -99,7 +99,7 @@ def get_latest_valid_checkpoint(
f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return connector.build_dummy_checkpoint(), False
# filter out any candidates that don't meet the criteria
checkpoint_candidates = [
@@ -140,11 +140,10 @@ def get_latest_valid_checkpoint(
logger.info(
f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch."
)
return checkpoint, False
try:
previous_checkpoint = load_checkpoint(
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
@@ -153,14 +152,14 @@ def get_latest_valid_checkpoint(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint."
)
return checkpoint, False
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
return previous_checkpoint, True
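# Example caller pattern for the new tuple return value: the boolean signals that a
# previous attempt's checkpoint is being resumed, which the docfetching task uses
# to decide whether to re-issue leftover document batches. This is an illustrative
# wrapper, not part of the change itself.
def _load_checkpoint_example(
    db_session: Session,
    cc_pair_id: int,
    search_settings_id: int,
    window_start: datetime,
    window_end: datetime,
    connector: BaseConnector,
) -> tuple[ConnectorCheckpoint, bool]:
    checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        window_start=window_start,
        window_end=window_end,
        connector=connector,
    )
    if resuming_from_checkpoint:
        # the caller is expected to re-issue any batches left in the file store
        pass
    return checkpoint, resuming_from_checkpoint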
def get_index_attempts_with_old_checkpoints(
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store()
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None

View File

@@ -1,3 +1,4 @@
import sys
import time
import traceback
from collections import defaultdict
@@ -5,7 +6,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from pydantic import BaseModel
from celery import Celery
from sqlalchemy.orm import Session
from onyx.access.access import source_should_fetch_permissions_during_indexing
@@ -18,18 +19,25 @@ from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -49,13 +57,16 @@ from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@@ -68,7 +79,7 @@ from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(propagate=False)
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
@@ -146,6 +157,10 @@ def _get_connector_runner(
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES:
logger.warning(
f"Document {doc.id} is too large: size={sys.getsizeof(doc)} bytes"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
@@ -180,25 +195,11 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
return cleaned_batch
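# Minimal illustration of why the cleaning above exists: Postgres text columns
# reject NUL bytes, so they are stripped from document text before persistence.
assert "abc\x00def".replace("\x00", "") == "abcdef"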
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session,
cc_pair_id: int,
search_settings_status: IndexModelStatus,
index_attempt_id: int,
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
@@ -206,27 +207,34 @@ def _check_connector_and_attempt_status(
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status == IndexingStatus.CANCELED:
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
f"Index Attempt is not running, status is {index_attempt_loop.status}"
)
if index_attempt_loop.celery_task_id is None:
raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id")
# TODO: delete this if it ends up unused
def _check_failure_threshold(
total_failures: int,
document_count: int,
@@ -257,6 +265,9 @@ def _check_failure_threshold(
)
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >1 month, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -271,7 +282,12 @@ def _run_indexing(
start_time = time.monotonic()  # just used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
@@ -292,7 +308,7 @@ def _run_indexing(
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = DocExtractionContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
@@ -317,6 +333,7 @@ def _run_indexing(
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
@@ -384,19 +401,6 @@ def _run_indexing(
httpx_client=HttpxPool.get("vespa"),
)
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
@@ -416,7 +420,9 @@ def _run_indexing(
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -439,7 +445,7 @@ def _run_indexing(
):
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint, _ = get_latest_valid_checkpoint(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
@@ -496,7 +502,10 @@ def _run_indexing(
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp,
ctx.cc_pair_id,
ctx.search_settings_status,
index_attempt_id,
)
# save record of any failures at the connector level
@@ -554,7 +563,16 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
@@ -815,6 +833,7 @@ def _run_indexing(
def run_indexing_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
connector_credential_pair_id: int,
@@ -832,7 +851,6 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
@@ -846,18 +864,516 @@ def run_indexing_entrypoint(
credential_id = attempt.connector_credential_pair.credential_id
logger.info(
f"Indexing starting{tenant_str}: "
f"Docfetching starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
connector_document_extraction(
app,
index_attempt_id,
attempt.connector_credential_pair_id,
attempt.search_settings_id,
tenant_id,
callback,
)
logger.info(
f"Indexing finished{tenant_str}: "
f"Docfetching finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
def connector_document_extraction(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""Extract documents from connector and queue them for indexing pipeline processing.
This is the first part of the split indexing process that runs the connector
and extracts documents, storing them in the filestore for later processing.
"""
start_time = time.monotonic()
logger.info(
f"Document extraction starting: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"tenant={tenant_id}"
)
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt = None
last_batch_num = 0 # used to continue from checkpointing
# comes from _run_indexing
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
if index_attempt.search_settings is None:
raise ValueError("Search settings must be set for indexing")
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
if index_attempt.connector_credential_pair.indexing_trigger is not None:
logger.info(
"Clearing indexing trigger: "
f"cc_pair={index_attempt.connector_credential_pair.id} "
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
index_attempt.connector_credential_pair.id, None, db_session
)
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
has_successful_attempt = (
index_attempt.connector_credential_pair.last_successful_index_time
is not None
)
earliest_index_time = (
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
)
should_fetch_permissions_during_indexing = (
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
and source_should_fetch_permissions_during_indexing(db_connector.source)
and is_primary
# if we've already successfully indexed, let the doc_sync job
# take care of doc-level permissions
and (from_beginning or not has_successful_attempt)
)
# Set up time windows for polling
last_successful_index_poll_range_end = (
earliest_index_time
if from_beginning
else get_last_successful_attempt_poll_range_end(
cc_pair_id=cc_pair_id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_poll_range_end, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
db_session=db_session,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# set time range in db
index_attempt.poll_range_start = window_start
index_attempt.poll_range_end = window_end
db_session.commit()
# TODO: maybe memory tracer here
# Set up connector runner
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
include_permissions=should_fetch_permissions_during_indexing,
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling
# OR
# if the last attempt was successful
if index_attempt.from_beginning or (
most_recent_attempt and most_recent_attempt.status.is_successful()
):
logger.info(
f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run"
)
batch_storage.cleanup_all_batches()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
logger.info(
f"Getting latest valid checkpoint for index attempt {index_attempt_id}"
)
checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
db_session=db_session,
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
if (
isinstance(connector_runner.connector, CheckpointedConnector)
and resuming_from_checkpoint
):
reissued_batch_count, completed_batches = reissue_old_batches(
batch_storage,
index_attempt_id,
cc_pair_id,
tenant_id,
app,
most_recent_attempt,
)
last_batch_num = reissued_batch_count + completed_batches
index_attempt.completed_batches = completed_batches
db_session.commit()
else:
logger.info(
f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run"
)
# for non-checkpointed connectors, throw out batches from previous unsuccessful attempts
# because we'll be getting those documents again anyways.
batch_storage.cleanup_all_batches()
# Save initial checkpoint
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
try:
batch_num = last_batch_num # starts at 0 if no last batch
total_doc_batches_queued = 0
total_failures = 0
document_count = 0
# Main extraction loop
while checkpoint.has_more:
logger.info(
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
# NOTE: this progress callback runs on every loop. We've seen cases
# where we loop many times with no new documents and eventually time
# out, so only doing the callback after indexing isn't sufficient.
# TODO: change to doc extraction if it doesn't break things
callback.progress("_run_indexing", 0)
# will exception if the connector/index attempt is marked as paused/failed
with get_session_with_current_tenant() as db_session_tmp:
_check_connector_and_attempt_status(
db_session_tmp,
cc_pair_id,
index_attempt.search_settings.status,
index_attempt_id,
)
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session:
create_index_attempt_error(
index_attempt_id,
cc_pair_id,
failure,
db_session,
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# Save checkpoint if provided
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing task, so if no batch we can just continue
if document_batch is None:
continue
# Clean documents and create batch
doc_batch_cleaned = strip_null_characters(document_batch)
batch_description = []
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
memory_tracer.increment_and_maybe_trace()
# Store documents in storage
batch_storage.store_batch(batch_num, doc_batch_cleaned)
# Create processing task data
processing_batch_data = {
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": batch_num, # 0-indexed
}
# Queue document processing task
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
batch_num += 1
total_doc_batches_queued += 1
logger.info(
f"Queued document processing batch: "
f"batch_num={batch_num} "
f"docs={len(doc_batch_cleaned)} "
f"attempt={index_attempt_id}"
)
# Check checkpoint size periodically
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# Save latest checkpoint
# NOTE: checkpointing is used to track which batches have
# been sent to the filestore, NOT which batches have been fully indexed
# as it used to be.
with get_session_with_current_tenant() as db_session:
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
elapsed_time = time.monotonic() - start_time
logger.info(
f"Document extraction completed: "
f"attempt={index_attempt_id} "
f"batches_queued={total_doc_batches_queued} "
f"elapsed={elapsed_time:.2f}s"
)
# Set total batches in database to signal extraction completion.
# Used by check_for_indexing to determine if the index attempt is complete.
with get_session_with_current_tenant() as db_session:
IndexingCoordination.set_total_batches(
db_session=db_session,
index_attempt_id=index_attempt_id,
total_batches=batch_num,
)
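# Illustrative completion check built on the two counters referenced above:
# total_batches is set once docfetching finishes enumerating batches, while
# docprocessing increments completed_batches as it indexes each one. The check
# itself is an assumption sketched for clarity, not the exact implementation.
def _attempt_fully_processed_example(
    total_batches: int | None, completed_batches: int
) -> bool:
    # total_batches stays None until extraction is finished
    return total_batches is not None and completed_batches >= total_batches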
except Exception as e:
logger.exception(
f"Document extraction failed: "
f"attempt={index_attempt_id} "
f"error={str(e)}"
)
# Do NOT clean up batches on failure; future runs will use those batches
# while docfetching will continue from the saved checkpoint if one exists
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if is_primary:
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {db_connector.id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=db_connector.id,
credential_id=db_credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
else:
with get_session_with_current_tenant() as db_session_temp:
# don't overwrite attempts that are already failed/canceled for another reason
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if index_attempt and index_attempt.status in [
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
]:
logger.info(
f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed."
)
raise e
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise e
finally:
memory_tracer.stop()
def reissue_old_batches(
batch_storage: DocumentBatchStorage,
index_attempt_id: int,
cc_pair_id: int,
tenant_id: str,
app: Celery,
most_recent_attempt: IndexAttempt | None,
) -> tuple[int, int]:
# When loading from a checkpoint, we need to start new docprocessing tasks
# tied to the new index attempt for any batches left over in the file store
old_batches = batch_storage.get_all_batches_for_cc_pair()
batch_storage.update_old_batches_to_new_index_attempt(old_batches)
for batch_id in old_batches:
logger.info(
f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}"
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs={
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": path_info.batch_num, # use same batch num as previously
},
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more
# than the last batch created by docfetching regardless of whether the batch
# is still in the filestore waiting for processing or not.
last_batch_num = len(old_batches) + recent_batches
logger.info(
f"Starting from batch {last_batch_num} due to "
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
)
return len(old_batches), recent_batches
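# Worked example of the resumption arithmetic above (illustrative values): with 3
# re-issued batches still in the file store and 7 batches already completed by the
# prior attempt, new batches from this attempt start at batch_num 10.
_example_reissued, _example_completed = 3, 7
_example_next_batch_num = _example_reissued + _example_completed  # == 10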

View File

@@ -725,9 +725,7 @@ def stream_chat_message_objects(
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []

View File

@@ -311,18 +311,33 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
try:
env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_INDEXING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
try:
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
)
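# The same "new env var, legacy fallback, then default" pattern used above, written
# as a small helper for clarity. This is an illustrative sketch; the module keeps
# the inlined version.
def _int_from_env_example(primary: str, legacy: str, default: int) -> int:
    raw = os.environ.get(primary) or os.environ.get(legacy) or str(default)
    try:
        return int(raw)
    except ValueError:
        return default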
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4

View File

@@ -65,7 +65,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
@@ -121,6 +122,8 @@ CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
# Heartbeat interval for indexing worker liveness detection
INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds
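# Illustrative sketch of how a worker can use the interval above for liveness
# reporting: a daemon thread periodically invokes a "touch" callback (for example,
# writing a heartbeat file or updating a timestamp). The callback is an assumption
# for the example, not an API defined in this module.
import threading
from collections.abc import Callable

def _start_heartbeat_example(
    touch: Callable[[], None],
    interval: int = INDEXING_WORKER_HEARTBEAT_INTERVAL,
) -> threading.Event:
    stop = threading.Event()

    def _beat() -> None:
        while not stop.is_set():
            touch()
            stop.wait(interval)

    threading.Thread(target=_beat, daemon=True).start()
    return stop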
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -331,9 +334,12 @@ class OnyxCeleryQueues:
CSV_GENERATION = "csv_generation"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
# Monitoring queue
MONITORING = "monitoring"
@@ -464,7 +470,11 @@ class OnyxCeleryTask:
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
# New split indexing tasks
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
DOCPROCESSING_TASK = "docprocessing_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"

View File

@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
image_section, _ = store_image_and_create_section(
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing image {key}")
continue

View File

@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
@@ -224,19 +223,17 @@ def _process_image_attachment(
"""Process an image attachment by saving it without generating a summary."""
try:
# Use the standardized image storage and section creation
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)

View File

@@ -5,8 +5,6 @@ from pathlib import Path
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -32,7 +29,6 @@ logger = setup_logger()
def _create_image_section(
image_data: bytes,
parent_file_name: str,
display_name: str,
link: str | None = None,
@@ -58,7 +54,6 @@ def _create_image_section(
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
image_data=image_data,
file_id=file_id,
display_name=display_name,
@@ -77,7 +72,6 @@ def _process_file(
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
) -> list[Document]:
"""
Process a file and return a list of Documents.
@@ -125,7 +119,6 @@ def _process_file(
try:
section, _ = _create_image_section(
image_data=image_data,
parent_file_name=file_id,
display_name=title,
)
@@ -196,7 +189,6 @@ def _process_file(
try:
image_section, stored_file_name = _create_image_section(
image_data=img_data,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
idx=idx,
@@ -260,37 +252,33 @@ class LocalFileConnector(LoadConnector):
"""
documents: list[Document] = []
for file_id in self.file_locations:
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":
connector = LocalFileConnector(

View File

@@ -1,6 +1,10 @@
import copy
import json
import os
import sys
import threading
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
@@ -1374,3 +1378,139 @@ class GoogleDriveConnector(
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
def get_credentials_from_env(email: str, oauth: bool) -> dict:
if oauth:
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
else:
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
refried_credential_string = json.dumps(json.loads(raw_credential_string))
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
cred_key = (
DB_CREDENTIALS_DICT_TOKEN_KEY
if oauth
else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
)
return {
cred_key: refried_credential_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: GoogleDriveCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> CheckpointOutput[GoogleDriveCheckpoint]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
def yield_all_docs_from_checkpoint_connector(
connector: GoogleDriveConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> Iterator[Document | ConnectorFailure]:
num_iterations = 0
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
yield failure
if document is not None:
yield document
if next_checkpoint is not None:
checkpoint = next_checkpoint
num_iterations += 1
if num_iterations > 100_000:
raise RuntimeError("Too many iterations. Infinite loop?")
if __name__ == "__main__":
import time
creds = get_credentials_from_env(
os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
)
connector = GoogleDriveConnector(
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=True,
specific_user_emails=None,
)
connector.load_credentials(creds)
max_fsize = 0
biggest_fsize = 0
num_errors = 0
start_time = time.time()
with open("stats.txt", "w") as f:
for num, doc_or_failure in enumerate(
yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
):
if num % 200 == 0:
f.write(f"Processed {num} files\n")
f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
biggest_fsize = max(biggest_fsize, max_fsize)
max_fsize = 0
if isinstance(doc_or_failure, Document):
max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
elif isinstance(doc_or_failure, ConnectorFailure):
num_errors += 1
print(f"Num errors: {num_errors}")
print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
print(f"Time taken: {time.time() - start_time:.2f} seconds")

View File

@@ -29,7 +29,6 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
@@ -143,17 +142,15 @@ def _download_and_extract_sections_basic(
# Store images for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
@@ -216,16 +213,14 @@ def _download_and_extract_sections_basic(
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections

View File

@@ -12,7 +12,6 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -68,10 +67,7 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
# load the HTML files
files = load_files_from_zip(file_content_io)

View File

@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible
@@ -363,6 +364,10 @@ class ConnectorFailure(BaseModel):
return values
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class OnyxMetadata(BaseModel):
# Note that doc_id cannot be overridden here as it may cause issues
# with the display functionalities in the UI. Ask @chris if clarification is needed.
@@ -373,3 +378,24 @@ class OnyxMetadata(BaseModel):
secondary_owners: list[BasicExpertInfo] | None = None
doc_updated_at: datetime | None = None
title: str | None = None
class DocExtractionContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
doc_extraction_complete_batch_num: int | None
class DocIndexingContext(BaseModel):
batches_done: int
total_failures: int
net_doc_change: int
total_chunks: int

View File

@@ -267,7 +267,7 @@ class NotionConnector(LoadConnector, PollConnector):
result = ""
for prop_name, prop in properties.items():
if not prop:
if not prop or not isinstance(prop, dict):
continue
try:

View File

@@ -992,7 +992,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
query_result = self.sf_client.safe_query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",

View File

@@ -1,18 +1,31 @@
import time
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
class OnyxSalesforce(Salesforce):
SOQL_MAX_SUBQUERIES = 20
@@ -52,6 +65,48 @@ class OnyxSalesforce(Salesforce):
return False
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query method with retry logic and rate limiting."""
try:
return super().query(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query_all(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query_all method with retry logic and rate limiting."""
try:
return super().query_all(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query_all: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
@staticmethod
def _make_child_objects_by_id_query(
object_id: str,
@@ -99,7 +154,7 @@ class OnyxSalesforce(Salesforce):
queryable_fields = type_to_queryable_fields[object_type]
query = get_object_by_id_query(object_id, object_type, queryable_fields)
result = self.query(query)
result = self.safe_query(query)
if not result:
return None
@@ -151,7 +206,7 @@ class OnyxSalesforce(Salesforce):
)
try:
result = self.query(query)
result = self.safe_query(query)
except Exception:
logger.exception(f"Query failed: {query=}")
else:
@@ -189,10 +244,25 @@ class OnyxSalesforce(Salesforce):
return child_records
@retry_builder(
tries=3,
delay=1,
backoff=2,
exceptions=(SalesforceRefusedRequest,),
)
def describe_type(self, name: str) -> Any:
sf_object = SFType(name, self.session_id, self.sf_instance)
result = sf_object.describe()
return result
try:
result = sf_object.describe()
return result
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for describe_type: {name}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
def get_queryable_fields_by_type(self, name: str) -> list[str]:
object_description = self.describe_type(name)
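For reference, a minimal sketch of how the retry and rate-limit decorators above compose on an arbitrary Salesforce call; the function name fetch_account_names and the SOQL query are hypothetical, only the decorator parameters mirror the ones used in safe_query.
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
from onyx.utils.retry_wrapper import retry_builder
@retry_builder(
    tries=5, delay=20, backoff=1.5, max_delay=60, exceptions=(SalesforceRefusedRequest,)
)
@rate_limit_builder(max_calls=50, period=60)
def fetch_account_names(sf_client: Salesforce) -> dict[str, Any]:
    # the inner decorator keeps the call under 50 requests per 60 s;
    # the outer decorator retries refused requests with exponential backoff
    return sf_client.query_all("SELECT Id, Name FROM Account")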

View File

@@ -1,5 +1,6 @@
import gc
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
@@ -7,13 +8,25 @@ from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
def _build_last_modified_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
@@ -71,6 +84,14 @@ def get_object_by_id_query(
return query
@retry_builder(
tries=5,
delay=2,
backoff=2,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def _object_type_has_api_data(
sf_client: Salesforce, sf_type: str, time_filter: str
) -> bool:
@@ -82,6 +103,15 @@ def _object_type_has_api_data(
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for object type check: {sf_type}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")

View File

@@ -25,7 +25,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
@@ -70,18 +69,16 @@ def update_image_sections_with_query(
logger.debug(
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_id}"
)
with get_session_with_current_tenant() as db_session:
file_record = get_default_file_store(db_session).read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(
f"Successfully loaded image data for {chunk.image_file_id}"
)
file_record = get_default_file_store().read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(f"Successfully loaded image data for {chunk.image_file_id}")
messages: list[BaseMessage] = [
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),

View File

@@ -229,7 +229,7 @@ def delete_messages_and_files_from_chat_session(
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_info in files or []:
file_store.delete_file(file_id=file_info.get("id"))

View File

@@ -264,6 +264,7 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
@@ -271,6 +272,8 @@ def get_connector_credential_pair_from_id(
if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
result = db_session.execute(stmt)
return result.scalar_one_or_none()

View File

@@ -849,7 +849,9 @@ def fetch_chunk_counts_for_documents(
# Create a dictionary of document_id to chunk_count
chunk_counts = {str(row.id): row.chunk_count or 0 for row in results}
# Return a list of tuples, using 0 for documents not found in the database
# Return a list of tuples, preserving `None` for documents not found or with
# an unknown chunk count. Callers should handle the `None` case and fall
# back to an existence check against the vector DB if necessary.
return [(doc_id, chunk_counts.get(doc_id, 0)) for doc_id in document_ids]

View File

@@ -305,6 +305,18 @@ def get_session_with_current_tenant() -> Generator[Session, None, None]:
yield session
@contextmanager
def get_session_with_current_tenant_if_none(
session: Session | None,
) -> Generator[Session, None, None]:
if session is None:
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
yield session
else:
yield session
# Used in multi tenant mode when need to refer to the shared `public` schema
@contextmanager
def get_session_with_shared_schema() -> Generator[Session, None, None]:
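A small sketch of the intended call pattern for the new helper: callers that already hold a session pass it straight through, while callers without one get a fresh tenant-scoped session; do_work is a hypothetical caller.
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
def do_work(db_session: Session | None = None) -> None:
    # with an existing session, the caller's transaction scope is reused;
    # with None, a session for the current tenant is opened and closed here
    with get_session_with_current_tenant_if_none(db_session) as session:
        session.flush()  # placeholder for real queries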

View File

@@ -43,6 +43,17 @@ def get_filerecord_by_file_id(
return filestore
def get_filerecord_by_prefix(
prefix: str,
db_session: Session,
) -> list[FileRecord]:
if not prefix:
return db_session.query(FileRecord).all()
return (
db_session.query(FileRecord).filter(FileRecord.file_id.like(f"{prefix}%")).all()
)
def delete_filerecord_by_file_id(
file_id: str,
db_session: Session,

View File

@@ -28,6 +28,8 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
@@ -95,23 +97,52 @@ def get_recent_attempts_for_cc_pair(
def get_index_attempt(
db_session: Session, index_attempt_id: int
db_session: Session,
index_attempt_id: int,
eager_load_cc_pair: bool = False,
eager_load_search_settings: bool = False,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.connector
)
)
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.credential
)
)
if eager_load_search_settings:
stmt = stmt.options(joinedload(IndexAttempt.search_settings))
return db_session.scalars(stmt).first()
def count_error_rows_for_index_attempt(
index_attempt_id: int,
db_session: Session,
) -> int:
return (
db_session.query(IndexAttemptError)
.filter(IndexAttemptError.index_attempt_id == index_attempt_id)
.count()
)
def create_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,
db_session: Session,
from_beginning: bool = False,
celery_task_id: str | None = None,
) -> int:
new_attempt = IndexAttempt(
connector_credential_pair_id=connector_credential_pair_id,
search_settings_id=search_settings_id,
from_beginning=from_beginning,
status=IndexingStatus.NOT_STARTED,
celery_task_id=celery_task_id,
)
db_session.add(new_attempt)
db_session.commit()
@@ -247,7 +278,7 @@ def mark_attempt_in_progress(
def mark_attempt_succeeded(
index_attempt_id: int,
db_session: Session,
) -> None:
) -> IndexAttempt:
try:
attempt = db_session.execute(
select(IndexAttempt)
@@ -256,6 +287,7 @@ def mark_attempt_succeeded(
).scalar_one()
attempt.status = IndexingStatus.SUCCESS
attempt.celery_task_id = None
db_session.commit()
# Add telemetry for index attempt status change
@@ -267,6 +299,7 @@ def mark_attempt_succeeded(
"cc_pair_id": attempt.connector_credential_pair_id,
},
)
return attempt
except Exception:
db_session.rollback()
raise
@@ -275,7 +308,7 @@ def mark_attempt_succeeded(
def mark_attempt_partially_succeeded(
index_attempt_id: int,
db_session: Session,
) -> None:
) -> IndexAttempt:
try:
attempt = db_session.execute(
select(IndexAttempt)
@@ -284,6 +317,7 @@ def mark_attempt_partially_succeeded(
).scalar_one()
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
attempt.celery_task_id = None
db_session.commit()
# Add telemetry for index attempt status change
@@ -295,6 +329,7 @@ def mark_attempt_partially_succeeded(
"cc_pair_id": attempt.connector_credential_pair_id,
},
)
return attempt
except Exception:
db_session.rollback()
raise
@@ -350,6 +385,7 @@ def mark_attempt_failed(
attempt.status = IndexingStatus.FAILED
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
attempt.celery_task_id = None
db_session.commit()
# Add telemetry for index attempt status change
@@ -373,16 +409,22 @@ def update_docs_indexed(
new_docs_indexed: int,
docs_removed_from_index: int,
) -> None:
"""Updates the docs_indexed and new_docs_indexed fields of an index attempt.
Adds the given values to the current values in the db"""
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
.with_for_update() # Locks the row when we try to update
).scalar_one()
attempt.total_docs_indexed = total_docs_indexed
attempt.new_docs_indexed = new_docs_indexed
attempt.docs_removed_from_index = docs_removed_from_index
attempt.total_docs_indexed = (
attempt.total_docs_indexed or 0
) + total_docs_indexed
attempt.new_docs_indexed = (attempt.new_docs_indexed or 0) + new_docs_indexed
attempt.docs_removed_from_index = (
attempt.docs_removed_from_index or 0
) + docs_removed_from_index
db_session.commit()
except Exception:
db_session.rollback()
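To make the additive semantics concrete, a hypothetical pair of calls against a fresh attempt (the attempt id is a placeholder):
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.index_attempt import update_docs_indexed
with get_session_with_current_tenant() as db_session:
    attempt_id = 7  # placeholder
    # first batch: counters start from None and are treated as 0
    update_docs_indexed(db_session=db_session, index_attempt_id=attempt_id,
                        total_docs_indexed=10, new_docs_indexed=10, docs_removed_from_index=0)
    # second batch adds on top: the attempt now reads 15 / 12 / 1
    update_docs_indexed(db_session=db_session, index_attempt_id=attempt_id,
                        total_docs_indexed=5, new_docs_indexed=2, docs_removed_from_index=1)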

View File

@@ -0,0 +1,307 @@
"""Database-based indexing coordination to replace Redis fencing."""
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import count_error_rows_for_index_attempt
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import get_index_attempt
from onyx.db.models import IndexAttempt
from onyx.utils.logger import setup_logger
logger = setup_logger()
INDEXING_PROGRESS_TIMEOUT_HOURS = 6
class CoordinationStatus(BaseModel):
"""Status of an indexing attempt's coordination."""
found: bool
total_batches: int | None
completed_batches: int
total_failures: int
total_docs: int
total_chunks: int
status: IndexingStatus | None = None
cancellation_requested: bool = False
class IndexingCoordination:
"""Database-based coordination for indexing tasks, replacing Redis fencing."""
@staticmethod
def try_create_index_attempt(
db_session: Session,
cc_pair_id: int,
search_settings_id: int,
celery_task_id: str,
from_beginning: bool = False,
) -> int | None:
"""
Try to create a new index attempt for the given CC pair and search settings.
Returns the index_attempt_id if successful, None if another attempt is already running.
This replaces the Redis fencing mechanism by using database constraints
and transactions to prevent duplicate attempts.
"""
try:
# Check for existing active attempts (this is the "fence" check)
existing_attempt = db_session.execute(
select(IndexAttempt)
.where(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.search_settings_id == search_settings_id,
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
.with_for_update(nowait=True)
).first()
if existing_attempt:
logger.info(
f"Indexing already in progress: "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"existing_attempt={existing_attempt[0].id}"
)
return None
# Create new index attempt (this is setting the "fence")
attempt_id = create_index_attempt(
connector_credential_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
from_beginning=from_beginning,
db_session=db_session,
celery_task_id=celery_task_id,
)
logger.info(
f"Created Index Attempt: "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"attempt_id={attempt_id} "
f"celery_task_id={celery_task_id}"
)
return attempt_id
except SQLAlchemyError as e:
logger.info(
f"Failed to create index attempt (likely race condition): "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={str(e)}"
)
db_session.rollback()
return None
@staticmethod
def check_cancellation_requested(
db_session: Session,
index_attempt_id: int,
) -> bool:
"""
Check if cancellation has been requested for this indexing attempt.
This replaces Redis termination signals.
"""
attempt = get_index_attempt(db_session, index_attempt_id)
return attempt.cancellation_requested if attempt else False
@staticmethod
def request_cancellation(
db_session: Session,
index_attempt_id: int,
) -> None:
"""
Request cancellation of an indexing attempt.
This replaces Redis termination signals.
"""
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt:
attempt.cancellation_requested = True
db_session.commit()
logger.info(f"Requested cancellation for attempt {index_attempt_id}")
@staticmethod
def set_total_batches(
db_session: Session,
index_attempt_id: int,
total_batches: int,
) -> None:
"""
Set the total number of batches for this indexing attempt.
Called by docfetching when extraction is complete.
"""
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt:
attempt.total_batches = total_batches
db_session.commit()
logger.info(
f"Set total batches: attempt={index_attempt_id} total={total_batches}"
)
@staticmethod
def update_batch_completion_and_docs(
db_session: Session,
index_attempt_id: int,
total_docs_indexed: int,
new_docs_indexed: int,
total_chunks: int,
) -> tuple[int, int | None]:
"""
Update batch completion and document counts atomically.
Returns (completed_batches, total_batches).
This extends the existing update_docs_indexed pattern.
"""
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update() # Same pattern as existing update_docs_indexed
).scalar_one()
# Existing document count updates
attempt.total_docs_indexed = (
attempt.total_docs_indexed or 0
) + total_docs_indexed
attempt.new_docs_indexed = (
attempt.new_docs_indexed or 0
) + new_docs_indexed
# New coordination updates
attempt.completed_batches = (attempt.completed_batches or 0) + 1
attempt.total_chunks = (attempt.total_chunks or 0) + total_chunks
db_session.commit()
logger.info(
f"Updated batch completion: "
f"attempt={index_attempt_id} "
f"completed={attempt.completed_batches} "
f"total={attempt.total_batches} "
f"docs={total_docs_indexed} "
)
return attempt.completed_batches, attempt.total_batches
except Exception:
db_session.rollback()
logger.exception(
f"Failed to update batch completion for attempt {index_attempt_id}"
)
raise
@staticmethod
def get_coordination_status(
db_session: Session,
index_attempt_id: int,
) -> CoordinationStatus:
"""
Get the current coordination status for an indexing attempt.
This replaces reading FileStore state files.
"""
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
return CoordinationStatus(
found=False,
total_batches=None,
completed_batches=0,
total_failures=0,
total_docs=0,
total_chunks=0,
status=None,
cancellation_requested=False,
)
return CoordinationStatus(
found=True,
total_batches=attempt.total_batches,
completed_batches=attempt.completed_batches,
total_failures=count_error_rows_for_index_attempt(
index_attempt_id, db_session
),
total_docs=attempt.total_docs_indexed or 0,
total_chunks=attempt.total_chunks,
status=attempt.status,
cancellation_requested=attempt.cancellation_requested,
)
@staticmethod
def get_orphaned_index_attempt_ids(db_session: Session) -> list[int]:
"""
Gets a list of potentially orphaned index attempts.
These are attempts in non-terminal state that have task IDs but may have died.
This replaces the old get_unfenced_index_attempt_ids function.
The actual orphan detection requires checking with Celery, which should be
done by the caller.
"""
# Find attempts that are active and have task IDs
# The caller needs to check each one with Celery to confirm orphaned status
active_attempts = (
db_session.execute(
select(IndexAttempt).where(
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
IndexAttempt.celery_task_id.isnot(None),
)
)
.scalars()
.all()
)
return [attempt.id for attempt in active_attempts]
@staticmethod
def update_progress_tracking(
db_session: Session,
index_attempt_id: int,
current_batches_completed: int,
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
) -> bool:
"""
Update progress tracking for stall detection.
Returns True if sufficient progress was made, False if stalled.
"""
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
logger.error(f"Index attempt {index_attempt_id} not found in database")
return False
current_time = get_db_current_time(db_session)
# No progress - check if this is the first time tracking
if attempt.last_progress_time is None:
# First time tracking - initialize
attempt.last_progress_time = current_time
attempt.last_batches_completed_count = current_batches_completed
db_session.commit()
return True
time_elapsed = (current_time - attempt.last_progress_time).total_seconds()
# only actually write to the db every timeout_hours / 2 (i.e. timeout_hours * 1800 seconds)
# this ensures that at most timeout_hours will pass with no recorded activity
if time_elapsed < timeout_hours * 1800:
return True
# Check if progress has been made
if current_batches_completed <= attempt.last_batches_completed_count:
# if between timeout_hours/2 and timeout_hours has passed
# without an update, we consider the attempt stalled
return False
# Progress made - update tracking
attempt.last_progress_time = current_time
attempt.last_batches_completed_count = current_batches_completed
db_session.commit()
return True
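To show how these pieces fit together, a condensed, hypothetical sketch of the coordination lifecycle; the driver function, the literal numbers, and the import of IndexingCoordination (whose module path is not shown in this diff) are assumptions for illustration.
from onyx.db.engine.sql_engine import get_session_with_current_tenant
def run_indexing_attempt(
    cc_pair_id: int, search_settings_id: int, celery_task_id: str
) -> None:
    with get_session_with_current_tenant() as db_session:
        # the "fence": at most one NOT_STARTED / IN_PROGRESS attempt per pair
        attempt_id = IndexingCoordination.try_create_index_attempt(
            db_session, cc_pair_id, search_settings_id, celery_task_id
        )
        if attempt_id is None:
            return  # another attempt is already running
        # docfetching side: announce how many batches were extracted
        IndexingCoordination.set_total_batches(db_session, attempt_id, total_batches=10)
        # docprocessing side: report each finished batch
        completed, total = IndexingCoordination.update_batch_completion_and_docs(
            db_session, attempt_id,
            total_docs_indexed=25, new_docs_indexed=5, total_chunks=300,
        )
        # workers poll the database for cancellation instead of a Redis signal
        if IndexingCoordination.check_cancellation_requested(db_session, attempt_id):
            return
        # monitoring reads one consolidated status object
        status = IndexingCoordination.get_coordination_status(db_session, attempt_id)
        if (
            status.total_batches is not None
            and status.completed_batches >= status.total_batches
        ):
            pass  # all batches processed; finalization would happen here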

View File

@@ -1612,9 +1612,7 @@ class SearchSettings(Base):
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
return self.reduced_dimension or self.model_dim
@staticmethod
def can_use_large_chunks(
@@ -1635,7 +1633,7 @@ class SearchSettings(Base):
class IndexAttempt(Base):
"""
Represents an attempt to index a group of 1 or more documents from a
Represents an attempt to index a group of 0 or more documents from a
source. For example, a single pull from Google Drive, a single event from
slack event API, or a single website crawl.
"""
@@ -1683,6 +1681,30 @@ class IndexAttempt(Base):
# can be taken to the FileStore to grab the actual checkpoint value
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
# NEW: Database-based coordination fields (replacing Redis fencing)
celery_task_id: Mapped[str | None] = mapped_column(String, nullable=True)
cancellation_requested: Mapped[bool] = mapped_column(Boolean, default=False)
# NEW: Batch coordination fields (replacing FileStore state)
total_batches: Mapped[int | None] = mapped_column(Integer, nullable=True)
completed_batches: Mapped[int] = mapped_column(Integer, default=0)
# TODO: unused, remove this column
total_failures_batch_level: Mapped[int] = mapped_column(Integer, default=0)
total_chunks: Mapped[int] = mapped_column(Integer, default=0)
# Progress tracking for stall detection
last_progress_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_batches_completed_count: Mapped[int] = mapped_column(Integer, default=0)
# NEW: Heartbeat tracking for worker liveness detection
heartbeat_counter: Mapped[int] = mapped_column(Integer, default=0)
last_heartbeat_value: Mapped[int] = mapped_column(Integer, default=0)
last_heartbeat_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -1733,6 +1755,13 @@ class IndexAttempt(Base):
"status",
desc("time_updated"),
),
# NEW: Index for coordination queries
Index(
"ix_index_attempt_active_coordination",
"connector_credential_pair_id",
"search_settings_id",
"status",
),
)
def __repr__(self) -> str:
@@ -1747,6 +1776,13 @@ class IndexAttempt(Base):
def is_finished(self) -> bool:
return self.status.is_terminal()
def is_coordination_complete(self) -> bool:
"""Check if all batches have been processed"""
return (
self.total_batches is not None
and self.completed_batches >= self.total_batches
)
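The heartbeat columns added above are meant for worker liveness detection; a rough sketch of one way they could be used follows. The worker/monitor split, helper names, and timeout are assumptions, not part of this commit.
import datetime
from sqlalchemy.orm import Session
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import IndexAttempt
def record_heartbeat(attempt: IndexAttempt, db_session: Session) -> None:
    # called periodically by the indexing worker while it is alive
    attempt.heartbeat_counter += 1
    db_session.commit()
def worker_appears_alive(
    attempt: IndexAttempt, db_session: Session, timeout_seconds: int = 300
) -> bool:
    # called by a monitoring task; counter progress refreshes the last-observed
    # values, otherwise staleness is judged by elapsed time
    now = get_db_current_time(db_session)
    if attempt.heartbeat_counter > attempt.last_heartbeat_value:
        attempt.last_heartbeat_value = attempt.heartbeat_counter
        attempt.last_heartbeat_time = now
        db_session.commit()
        return True
    if attempt.last_heartbeat_time is None:
        return False
    return (now - attempt.last_heartbeat_time).total_seconds() < timeout_seconds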
class IndexAttemptError(Base):
__tablename__ = "index_attempt_errors"
@@ -3151,7 +3187,7 @@ class PublicExternalUserGroup(Base):
class UsageReport(Base):
"""This stores metadata about usage reports generated by admin including user who generated
them as well las the period they cover. The actual zip file of the report is stored as a lo
them as well as the period they cover. The actual zip file of the report is stored as a lo
using the FileRecord
"""

View File

@@ -46,7 +46,7 @@ def create_user_files(
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(files, db_session)
upload_response = upload_files(files)
user_files = []
for file_path, file in zip(upload_response.file_paths, files):

View File

@@ -45,7 +45,7 @@ class IndexBatchParams:
Information necessary for efficiently indexing a batch of documents
"""
doc_id_to_previous_chunk_cnt: dict[str, int | None]
doc_id_to_previous_chunk_cnt: dict[str, int]
doc_id_to_new_chunk_cnt: dict[str, int]
tenant_id: str
large_chunks_enabled: bool

View File

@@ -1,8 +1,6 @@
from io import BytesIO
from typing import Tuple
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ImageSection
from onyx.file_store.file_store import get_default_file_store
@@ -12,7 +10,6 @@ logger = setup_logger()
def store_image_and_create_section(
db_session: Session,
image_data: bytes,
file_id: str,
display_name: str,
@@ -24,7 +21,6 @@ def store_image_and_create_section(
Stores an image in FileStore and creates an ImageSection object without summarization.
Args:
db_session: Database session
image_data: Raw image bytes
file_id: Base identifier for the file
display_name: Human-readable name for the image
@@ -38,7 +34,7 @@ def store_image_and_create_section(
"""
# Storage logic
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_id = file_store.save_file(
content=BytesIO(image_data),
display_name=display_name,

View File

@@ -0,0 +1,228 @@
import json
from abc import ABC
from abc import abstractmethod
from enum import Enum
from io import StringIO
from typing import List
from typing import Optional
from typing import TypeAlias
from pydantic import BaseModel
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import DocIndexingContext
from onyx.connectors.models import Document
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
class DocumentBatchStorageStateType(str, Enum):
EXTRACTION = "extraction"
INDEXING = "indexing"
DocumentStorageState: TypeAlias = DocExtractionContext | DocIndexingContext
STATE_TYPE_TO_MODEL: dict[str, type[DocumentStorageState]] = {
DocumentBatchStorageStateType.EXTRACTION.value: DocExtractionContext,
DocumentBatchStorageStateType.INDEXING.value: DocIndexingContext,
}
class BatchStoragePathInfo(BaseModel):
cc_pair_id: int
index_attempt_id: int
batch_num: int
class DocumentBatchStorage(ABC):
"""Abstract base class for document batch storage implementations."""
def __init__(self, cc_pair_id: int, index_attempt_id: int):
self.cc_pair_id = cc_pair_id
self.index_attempt_id = index_attempt_id
self.base_path = f"{self._per_cc_pair_base_path()}/{index_attempt_id}"
@abstractmethod
def store_batch(self, batch_num: int, documents: List[Document]) -> None:
"""Store a batch of documents."""
@abstractmethod
def get_batch(self, batch_num: int) -> Optional[List[Document]]:
"""Retrieve a batch of documents."""
@abstractmethod
def delete_batch_by_name(self, batch_file_name: str) -> None:
"""Delete a specific batch."""
@abstractmethod
def delete_batch_by_num(self, batch_num: int) -> None:
"""Delete a specific batch."""
@abstractmethod
def cleanup_all_batches(self) -> None:
"""Clean up all batches and state for this index attempt."""
@abstractmethod
def get_all_batches_for_cc_pair(self) -> list[str]:
"""Get all IDs of batches stored in the file store."""
@abstractmethod
def update_old_batches_to_new_index_attempt(self, batch_names: list[str]) -> None:
"""Update all batches to the new index attempt."""
"""
This is used when we need to re-issue docprocessing tasks for a new index attempt.
We need to update the batch file names to the new index attempt ID.
"""
@abstractmethod
def extract_path_info(self, path: str) -> BatchStoragePathInfo | None:
"""Extract path info from a path."""
def _serialize_documents(self, documents: list[Document]) -> str:
"""Serialize documents to JSON string."""
# Use mode='json' to properly serialize datetime and other complex types
return json.dumps([doc.model_dump(mode="json") for doc in documents], indent=2)
def _deserialize_documents(self, data: str) -> list[Document]:
"""Deserialize documents from JSON string."""
doc_dicts = json.loads(data)
return [Document.model_validate(doc_dict) for doc_dict in doc_dicts]
def _per_cc_pair_base_path(self) -> str:
"""Get the base path for the cc pair."""
return f"iab/{self.cc_pair_id}"
class FileStoreDocumentBatchStorage(DocumentBatchStorage):
"""FileStore-based implementation of document batch storage."""
def __init__(self, cc_pair_id: int, index_attempt_id: int, file_store: FileStore):
super().__init__(cc_pair_id, index_attempt_id)
self.file_store = file_store
def _get_batch_file_name(self, batch_num: int) -> str:
"""Generate file name for a document batch."""
return f"{self.base_path}/{batch_num}.json"
def store_batch(self, batch_num: int, documents: list[Document]) -> None:
"""Store a batch of documents using FileStore."""
file_name = self._get_batch_file_name(batch_num)
try:
data = self._serialize_documents(documents)
content = StringIO(data)
self.file_store.save_file(
file_id=file_name,
content=content,
display_name=f"Document Batch {batch_num}",
file_origin=FileOrigin.OTHER,
file_type="application/json",
file_metadata={
"batch_num": batch_num,
"document_count": str(len(documents)),
},
)
logger.debug(
f"Stored batch {batch_num} with {len(documents)} documents to FileStore as {file_name}"
)
except Exception as e:
logger.error(f"Failed to store batch {batch_num}: {e}")
raise
def get_batch(self, batch_num: int) -> list[Document] | None:
"""Retrieve a batch of documents from FileStore."""
file_name = self._get_batch_file_name(batch_num)
try:
# Check if file exists
if not self.file_store.has_file(
file_id=file_name,
file_origin=FileOrigin.OTHER,
file_type="application/json",
):
logger.warning(
f"Batch {batch_num} not found in FileStore with name {file_name}"
)
return None
content_io = self.file_store.read_file(file_name)
data = content_io.read().decode("utf-8")
documents = self._deserialize_documents(data)
logger.debug(
f"Retrieved batch {batch_num} with {len(documents)} documents from FileStore"
)
return documents
except Exception as e:
logger.error(f"Failed to retrieve batch {batch_num}: {e}")
raise
def delete_batch_by_name(self, batch_file_name: str) -> None:
"""Delete a specific batch from FileStore."""
self.file_store.delete_file(batch_file_name)
logger.debug(f"Deleted batch {batch_file_name} from FileStore")
def delete_batch_by_num(self, batch_num: int) -> None:
"""Delete a specific batch from FileStore."""
batch_file_name = self._get_batch_file_name(batch_num)
self.delete_batch_by_name(batch_file_name)
logger.debug(f"Deleted batch num {batch_num} {batch_file_name} from FileStore")
def cleanup_all_batches(self) -> None:
"""Clean up all batches for this index attempt."""
for batch_file_name in self.get_all_batches_for_cc_pair():
self.delete_batch_by_name(batch_file_name)
def get_all_batches_for_cc_pair(self) -> list[str]:
"""Get all IDs of batches stored in the file store for the cc pair
this batch store was initialized with.
This includes any batches left over from a previous
indexing attempt that need to be processed.
"""
return [
file.file_id
for file in self.file_store.list_files_by_prefix(
self._per_cc_pair_base_path()
)
]
def update_old_batches_to_new_index_attempt(self, batch_names: list[str]) -> None:
"""Update all batches to the new index attempt."""
for batch_file_name in batch_names:
path_info = self.extract_path_info(batch_file_name)
if path_info is None:
continue
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
self.file_store.change_file_id(batch_file_name, new_batch_file_name)
def extract_path_info(self, path: str) -> BatchStoragePathInfo | None:
"""Extract path info from a path."""
path_spl = path.split("/")
# TODO: remove this in a few months, just for backwards compatibility
if len(path_spl) == 3:
path_spl = ["iab"] + path_spl
try:
_, cc_pair_id, index_attempt_id, batch_num = path_spl
return BatchStoragePathInfo(
cc_pair_id=int(cc_pair_id),
index_attempt_id=int(index_attempt_id),
batch_num=int(batch_num.split(".")[0]), # remove .json
)
except Exception as e:
logger.error(f"Failed to extract path info from {path}: {e}")
return None
def get_document_batch_storage(
cc_pair_id: int, index_attempt_id: int
) -> DocumentBatchStorage:
"""Factory function to get the configured document batch storage implementation."""
# The get_default_file_store will now correctly use S3BackedFileStore
# or other configured stores based on environment variables
file_store = get_default_file_store()
return FileStoreDocumentBatchStorage(cc_pair_id, index_attempt_id, file_store)
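A brief usage sketch of the batch storage defined above; the IDs and document list are placeholders, and the import of get_document_batch_storage is omitted since its module path is not shown in this diff. Batches land under iab/{cc_pair_id}/{index_attempt_id}/{batch_num}.json, which is what extract_path_info parses back out.
from onyx.connectors.models import Document
docs: list[Document] = []  # extracted documents for one batch (placeholder)
storage = get_document_batch_storage(cc_pair_id=42, index_attempt_id=7)
# docfetching writes each extracted batch
storage.store_batch(batch_num=0, documents=docs)
# docprocessing reads it back (returns None if the batch file is missing)
batch = storage.get_batch(batch_num=0)
# batches left over from a prior attempt can be re-pointed at this attempt
storage.update_old_batches_to_new_index_attempt(storage.get_all_batches_for_cc_pair())
# once the attempt has fully finished
storage.cleanup_all_batches()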

View File

@@ -22,10 +22,14 @@ from onyx.configs.app_configs import S3_FILE_STORE_BUCKET_NAME
from onyx.configs.app_configs import S3_FILE_STORE_PREFIX
from onyx.configs.app_configs import S3_VERIFY_SSL
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.file_record import delete_filerecord_by_file_id
from onyx.db.file_record import get_filerecord_by_file_id
from onyx.db.file_record import get_filerecord_by_file_id_optional
from onyx.db.file_record import get_filerecord_by_prefix
from onyx.db.file_record import upsert_filerecord
from onyx.db.models import FileRecord
from onyx.db.models import FileRecord as FileStoreModel
from onyx.file_store.s3_key_utils import generate_s3_key
from onyx.utils.file import FileWithMimeType
@@ -129,13 +133,29 @@ class FileStore(ABC):
Get the file + parse out the mime type.
"""
@abstractmethod
def change_file_id(self, old_file_id: str, new_file_id: str) -> None:
"""
Change the file ID of an existing file.
Parameters:
- old_file_id: Current file ID
- new_file_id: New file ID to assign
"""
raise NotImplementedError
@abstractmethod
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
"""
List all file records whose file IDs start with the given prefix.
"""
class S3BackedFileStore(FileStore):
"""Isn't necessarily S3, but is any S3-compatible storage (e.g. MinIO)"""
def __init__(
self,
db_session: Session,
bucket_name: str,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
@@ -144,7 +164,6 @@ class S3BackedFileStore(FileStore):
s3_prefix: str | None = None,
s3_verify_ssl: bool = True,
) -> None:
self.db_session = db_session
self._s3_client: S3Client | None = None
self._bucket_name = bucket_name
self._aws_access_key_id = aws_access_key_id
@@ -272,10 +291,12 @@ class S3BackedFileStore(FileStore):
file_id: str,
file_origin: FileOrigin,
file_type: str,
db_session: Session | None = None,
) -> bool:
file_record = get_filerecord_by_file_id_optional(
file_id=file_id, db_session=self.db_session
)
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id_optional(
file_id=file_id, db_session=db_session
)
return (
file_record is not None
and file_record.file_origin == file_origin
@@ -290,6 +311,7 @@ class S3BackedFileStore(FileStore):
file_type: str,
file_metadata: dict[str, Any] | None = None,
file_id: str | None = None,
db_session: Session | None = None,
) -> str:
if file_id is None:
file_id = str(uuid.uuid4())
@@ -314,27 +336,33 @@ class S3BackedFileStore(FileStore):
ContentType=file_type,
)
# Save metadata to database
upsert_filerecord(
file_id=file_id,
display_name=display_name or file_id,
file_origin=file_origin,
file_type=file_type,
bucket_name=bucket_name,
object_key=s3_key,
db_session=self.db_session,
file_metadata=file_metadata,
)
self.db_session.commit()
with get_session_with_current_tenant_if_none(db_session) as db_session:
# Save metadata to database
upsert_filerecord(
file_id=file_id,
display_name=display_name or file_id,
file_origin=file_origin,
file_type=file_type,
bucket_name=bucket_name,
object_key=s3_key,
db_session=db_session,
file_metadata=file_metadata,
)
db_session.commit()
return file_id
def read_file(
self, file_id: str, mode: str | None = None, use_tempfile: bool = False
self,
file_id: str,
mode: str | None = None,
use_tempfile: bool = False,
db_session: Session | None = None,
) -> IO[bytes]:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=self.db_session
)
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
s3_client = self._get_s3_client()
try:
@@ -356,32 +384,107 @@ class S3BackedFileStore(FileStore):
else:
return BytesIO(file_content)
def read_file_record(self, file_id: str) -> FileStoreModel:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=self.db_session
)
def read_file_record(
self, file_id: str, db_session: Session | None = None
) -> FileStoreModel:
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
return file_record
def delete_file(self, file_id: str) -> None:
try:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=self.db_session
)
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
with get_session_with_current_tenant_if_none(db_session) as db_session:
try:
# Delete from external storage
s3_client = self._get_s3_client()
s3_client.delete_object(
Bucket=file_record.bucket_name, Key=file_record.object_key
)
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
if not file_record.bucket_name:
logger.error(
f"File record {file_id} with key {file_record.object_key} "
"has no bucket name, cannot delete from filestore"
)
delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
db_session.commit()
return
# Delete metadata from database
delete_filerecord_by_file_id(file_id=file_id, db_session=self.db_session)
# Delete from external storage
s3_client = self._get_s3_client()
s3_client.delete_object(
Bucket=file_record.bucket_name, Key=file_record.object_key
)
self.db_session.commit()
# Delete metadata from database
delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
except Exception:
self.db_session.rollback()
raise
db_session.commit()
except Exception:
db_session.rollback()
raise
def change_file_id(
self, old_file_id: str, new_file_id: str, db_session: Session | None = None
) -> None:
with get_session_with_current_tenant_if_none(db_session) as db_session:
try:
# Get the existing file record
old_file_record = get_filerecord_by_file_id(
file_id=old_file_id, db_session=db_session
)
# Generate new S3 key for the new file ID
new_s3_key = self._get_s3_key(new_file_id)
# Copy S3 object to new key
s3_client = self._get_s3_client()
bucket_name = self._get_bucket_name()
copy_source = (
f"{old_file_record.bucket_name}/{old_file_record.object_key}"
)
s3_client.copy_object(
CopySource=copy_source,
Bucket=bucket_name,
Key=new_s3_key,
MetadataDirective="COPY",
)
# Create new file record with new file_id
# Cast file_metadata to the expected type
file_metadata = cast(
dict[Any, Any] | None, old_file_record.file_metadata
)
upsert_filerecord(
file_id=new_file_id,
display_name=old_file_record.display_name,
file_origin=old_file_record.file_origin,
file_type=old_file_record.file_type,
bucket_name=bucket_name,
object_key=new_s3_key,
db_session=db_session,
file_metadata=file_metadata,
)
# Delete old S3 object
s3_client.delete_object(
Bucket=old_file_record.bucket_name, Key=old_file_record.object_key
)
# Delete old file record
delete_filerecord_by_file_id(file_id=old_file_id, db_session=db_session)
db_session.commit()
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to change file ID from {old_file_id} to {new_file_id}: {e}"
)
raise
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
mime_type: str = "application/octet-stream"
@@ -395,8 +498,18 @@ class S3BackedFileStore(FileStore):
except Exception:
return None
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
"""
List all file records whose file IDs start with the given prefix.
"""
with get_session_with_current_tenant() as db_session:
file_records = get_filerecord_by_prefix(
prefix=prefix, db_session=db_session
)
return file_records
def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
def get_s3_file_store() -> S3BackedFileStore:
"""
Returns the S3 file store implementation.
"""
@@ -409,7 +522,6 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
)
return S3BackedFileStore(
db_session=db_session,
bucket_name=bucket_name,
aws_access_key_id=S3_AWS_ACCESS_KEY_ID,
aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY,
@@ -420,7 +532,7 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
)
def get_default_file_store(db_session: Session) -> FileStore:
def get_default_file_store() -> FileStore:
"""
Returns the configured file store implementation.
@@ -445,4 +557,4 @@ def get_default_file_store(db_session: Session) -> FileStore:
Other S3-compatible storage (Digital Ocean, Linode, etc.):
- Same as MinIO, but set appropriate S3_ENDPOINT_URL
"""
return get_s3_file_store(db_session)
return get_s3_file_store()
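A small sketch of the new session-less call pattern for the file store: with no db_session argument, each operation opens a tenant-scoped session internally via get_session_with_current_tenant_if_none. The file name and contents are placeholders.
from io import BytesIO
from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store
file_store = get_default_file_store()  # no db_session argument anymore
file_id = file_store.save_file(
    content=BytesIO(b"hello world"),
    display_name="example.txt",
    file_origin=FileOrigin.OTHER,
    file_type="text/plain",
)
data = file_store.read_file(file_id, mode="b").read()
file_store.delete_file(file_id)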

View File

@@ -8,7 +8,6 @@ import requests
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
@@ -49,28 +48,23 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Use a separate session to avoid committing the caller's transaction
try:
with get_session_with_current_tenant() as file_store_session:
file_store = get_default_file_store(file_store_session)
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
return False
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
) -> InMemoryChatFile:
file_io = get_default_file_store(db_session).read_file(
file_descriptor["id"], mode="b"
)
def load_chat_file(file_descriptor: FileDescriptor) -> InMemoryChatFile:
file_io = get_default_file_store().read_file(file_descriptor["id"], mode="b")
return InMemoryChatFile(
file_id=file_descriptor["id"],
content=file_io.read(),
@@ -82,7 +76,6 @@ def load_chat_file(
def load_all_chat_files(
chat_messages: list[ChatMessage],
file_descriptors: list[FileDescriptor],
db_session: Session,
) -> list[InMemoryChatFile]:
file_descriptors_for_history: list[FileDescriptor] = []
for chat_message in chat_messages:
@@ -93,7 +86,7 @@ def load_all_chat_files(
list[InMemoryChatFile],
run_functions_tuples_in_parallel(
[
(load_chat_file, (file, db_session))
(load_chat_file, (file,))
for file in file_descriptors + file_descriptors_for_history
]
),
@@ -117,7 +110,7 @@ def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
raise ValueError(f"User file with id {file_id} not found")
# Get the file record to determine the appropriate chat file type
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_record = file_store.read_file_record(user_file.file_id)
# Determine appropriate chat file type based on the original file's MIME type
@@ -263,34 +256,29 @@ def get_user_files_as_user(
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_current_tenant() as db_session:
response = requests.get(url)
response.raise_for_status()
response = requests.get(url)
response.raise_for_status()
file_io = BytesIO(response.content)
file_store = get_default_file_store(db_session)
file_id = file_store.save_file(
content=file_io,
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type="image/png;base64",
)
return file_id
file_io = BytesIO(response.content)
file_store = get_default_file_store()
file_id = file_store.save_file(
content=file_io,
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type="image/png;base64",
)
return file_id
def save_file_from_base64(base64_string: str) -> str:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_id = file_store.save_file(
content=BytesIO(base64.b64decode(base64_string)),
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
)
return file_id
file_store = get_default_file_store()
file_id = file_store.save_file(
content=BytesIO(base64.b64decode(base64_string)),
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
)
return file_id
def save_file(

View File

@@ -4,6 +4,13 @@ from typing import Any
import httpx
def make_default_kwargs() -> dict[str, Any]:
return {
"http2": True,
"limits": httpx.Limits(),
}
class HttpxPool:
"""Class to manage a global httpx Client instance"""
@@ -11,10 +18,6 @@ class HttpxPool:
_lock: threading.Lock = threading.Lock()
# Default parameters for creation
DEFAULT_KWARGS = {
"http2": True,
"limits": lambda: httpx.Limits(),
}
def __init__(self) -> None:
pass
@@ -22,7 +25,7 @@ class HttpxPool:
@classmethod
def _init_client(cls, **kwargs: Any) -> httpx.Client:
"""Private helper method to create and return an httpx.Client."""
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
merged_kwargs = {**(make_default_kwargs()), **kwargs}
return httpx.Client(**merged_kwargs)
@classmethod
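The change above replaces the class-level DEFAULT_KWARGS dict, which stored a lambda under "limits" and would have been handed to httpx.Client as-is, with a factory that builds a fresh httpx.Limits() per client. A minimal sketch of the resulting construction path; the timeout override is purely illustrative.
from typing import Any
import httpx
def make_default_kwargs() -> dict[str, Any]:
    # defaults are built per call, so each client gets its own httpx.Limits() instance
    return {"http2": True, "limits": httpx.Limits()}
merged = {**make_default_kwargs(), "timeout": httpx.Timeout(10.0)}  # illustrative override
client = httpx.Client(**merged)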

View File

@@ -4,6 +4,7 @@ from abc import abstractmethod
from collections import defaultdict
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocumentFailure
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -261,6 +262,11 @@ def embed_chunks_with_failure_handling(
),
[],
)
except ConnectorStopSignal as e:
logger.warning(
"Connector stop signal detected in embed_chunks_with_failure_handling"
)
raise e
except Exception:
logger.exception("Failed to embed chunk batch. Trying individual docs.")
# wait a couple seconds to let any rate limits or temporary issues resolve

View File

@@ -1,6 +1,5 @@
from collections import defaultdict
from collections.abc import Callable
from functools import partial
from typing import Protocol
from pydantic import BaseModel
@@ -23,6 +22,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
@@ -41,7 +41,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.search_settings import get_active_search_settings
@@ -62,7 +61,6 @@ from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
@@ -290,6 +288,10 @@ def index_doc_batch_with_handler(
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)
except ConnectorStopSignal as e:
logger.warning("Connector stop signal detected in index_doc_batch_with_handler")
raise e
except Exception as e:
# don't log the batch directly, it's too much text
document_ids = [doc.id for doc in document_batch]
@@ -496,36 +498,33 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# Try to get image summary
try:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_record = file_store.read_file_record(
file_record = file_store.read_file_record(
file_id=section.image_file_id
)
if not file_record:
logger.warning(
f"Image file {section.image_file_id} not found in FileStore"
)
processed_section.text = "[Image could not be processed]"
else:
# Get the image data
image_data_io = file_store.read_file(
file_id=section.image_file_id
)
if not file_record:
logger.warning(
f"Image file {section.image_file_id} not found in FileStore"
)
image_data = image_data_io.read()
summary = summarize_image_with_error_handling(
llm=llm,
image_data=image_data,
context_name=file_record.display_name or "Image",
)
processed_section.text = "[Image could not be processed]"
if summary:
processed_section.text = summary
else:
# Get the image data
image_data_io = file_store.read_file(
file_id=section.image_file_id
)
image_data = image_data_io.read()
summary = summarize_image_with_error_handling(
llm=llm,
image_data=image_data,
context_name=file_record.display_name or "Image",
)
if summary:
processed_section.text = summary
else:
processed_section.text = (
"[Image could not be summarized]"
)
processed_section.text = "[Image could not be summarized]"
except Exception as e:
logger.error(f"Error processing image section: {e}")
processed_section.text = "[Error processing image]"
@@ -832,7 +831,7 @@ def index_doc_batch(
)
)
doc_id_to_previous_chunk_cnt: dict[str, int | None] = {
doc_id_to_previous_chunk_cnt: dict[str, int] = {
document_id: chunk_count
for document_id, chunk_count in fetch_chunk_counts_for_documents(
document_ids=updatable_ids,
@@ -1029,7 +1028,7 @@ def index_doc_batch(
db_session.commit()
result = IndexingPipelineResult(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
new_docs=len([r for r in insertion_records if not r.already_existed]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
failures=vector_db_write_failures + embedding_failures,
@@ -1038,8 +1037,10 @@ def index_doc_batch(
return result
def build_indexing_pipeline(
def run_indexing_pipeline(
*,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
@@ -1047,8 +1048,7 @@ def build_indexing_pipeline(
tenant_id: str,
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
) -> IndexingPipelineResult:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
all_search_settings = get_active_search_settings(db_session)
if (
@@ -1078,15 +1078,15 @@ def build_indexing_pipeline(
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
)
return partial(
index_doc_batch_with_handler,
return index_doc_batch_with_handler(
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
document_batch=document_batch,
index_attempt_metadata=index_attempt_metadata,
ignore_time_skip=ignore_time_skip,
db_session=db_session,
tenant_id=tenant_id,
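Based on the signature changes in this hunk, the builder that used to return a partial is now a direct call that processes the batch and returns an IndexingPipelineResult. A sketch of the new call shape; variable names here are placeholders, and the chunker is presumably constructed internally when omitted:

result = run_indexing_pipeline(
    document_batch=documents,                      # list[Document]
    index_attempt_metadata=attempt_metadata,       # IndexAttemptMetadata
    embedder=embedder,                             # IndexingEmbedder
    information_content_classification_model=classifier,
    document_index=document_index,
    db_session=db_session,
    tenant_id=tenant_id,
    ignore_time_skip=False,
)
# result.new_docs / result.total_docs / result.total_chunks / result.failures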

View File

@@ -259,7 +259,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
# set up the file store (e.g. create bucket if needed). On multi-tenant,
# this is done via IaC
get_default_file_store(db_session).initialize()
get_default_file_store().initialize()
else:
setup_multitenant_onyx()

View File

@@ -22,6 +22,7 @@ from onyx.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.models import ConnectorStopSignal
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.natural_language_processing.exceptions import (
@@ -198,7 +199,9 @@ class EmbeddingModel:
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
raise ConnectorStopSignal(
"_batch_encode_texts detected stop signal"
)
embed_request = EmbedRequest(
model_name=self.model_name,

View File

@@ -1,8 +1,5 @@
import time
import redis
from onyx.db.models import SearchSettings
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -12,68 +9,33 @@ from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_pool import get_redis_client
# TODO: reduce dependence on redis
class RedisConnector:
"""Composes several classes to simplify interacting with a connector and its
associated background tasks / associated redis interactions."""
def __init__(self, tenant_id: str, id: int) -> None:
def __init__(self, tenant_id: str, cc_pair_id: int) -> None:
"""id: a connector credential pair id"""
self.tenant_id: str = tenant_id
self.id: int = id
self.cc_pair_id: int = cc_pair_id
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
self.stop = RedisConnectorStop(tenant_id, cc_pair_id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, cc_pair_id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, cc_pair_id, self.redis)
self.permissions = RedisConnectorPermissionSync(
tenant_id, cc_pair_id, self.redis
)
self.external_group_sync = RedisConnectorExternalGroupSync(
tenant_id, id, self.redis
tenant_id, cc_pair_id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.id, search_settings_id, self.redis
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
)
def wait_for_indexing_termination(
self,
search_settings_list: list[SearchSettings],
timeout: float = 15.0,
) -> bool:
"""
Returns True if all indexing for the given redis connector is finished within the given timeout.
Returns False if the timeout is exceeded
This check does not guarantee that current indexings being terminated
won't get restarted midflight
"""
finished = False
start = time.monotonic()
while True:
still_indexing = False
for search_settings in search_settings_list:
redis_connector_index = self.new_index(search_settings.id)
if redis_connector_index.fenced:
still_indexing = True
break
if not still_indexing:
finished = True
break
now = time.monotonic()
if now - start > timeout:
break
time.sleep(1)
continue
return finished
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""

View File

@@ -58,10 +58,7 @@ class RedisConnectorDelete:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorDeletePayload | None:
@@ -93,10 +90,7 @@ class RedisConnectorDelete:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
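The fenced/active simplifications in this file (and the matching ones below) collapse the if/return pairs because redis-py's exists() returns an integer count of matching keys, so bool() is sufficient. A tiny standalone illustration, assuming a reachable local Redis:

import redis

r = redis.Redis()
r.set("example_fence_key", 0)
assert r.exists("example_fence_key") == 1      # integer count, not a bool
assert bool(r.exists("missing_key")) is False
r.delete("example_fence_key")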

View File

@@ -88,10 +88,7 @@ class RedisConnectorPermissionSync:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
@@ -128,10 +125,7 @@ class RedisConnectorPermissionSync:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
@property
def generator_complete(self) -> int | None:

View File

@@ -84,10 +84,7 @@ class RedisConnectorExternalGroupSync:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
@@ -125,10 +122,7 @@ class RedisConnectorExternalGroupSync:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
@property
def generator_complete(self) -> int | None:

View File

@@ -1,13 +1,10 @@
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4
import redis
from pydantic import BaseModel
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
class RedisConnectorIndexPayload(BaseModel):
@@ -31,7 +28,10 @@ class RedisConnectorIndex:
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
DB_LOCK_PREFIX = "da_lock:indexing:db"
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
@@ -53,130 +53,34 @@ class RedisConnectorIndex:
def __init__(
self,
tenant_id: str,
id: int,
cc_pair_id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str = tenant_id
self.id = id
self.cc_pair_id = cc_pair_id
self.search_settings_id = search_settings_id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
self.generator_progress_key = (
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
)
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.filestore_lock_key = (
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.connector_active_key = (
f"{self.CONNECTOR_ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.per_worker_lock_key = (
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.terminate_key = (
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
def generate_generator_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
@property
def fenced(self) -> bool:
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
def set_fence(
self,
payload: RedisConnectorIndexPayload | None,
) -> None:
if not payload:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def terminating(self, celery_task_id: str) -> bool:
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(
f"{self.terminate_key}_{celery_task_id}", 0, ex=self.TERMINATE_TTL
)
def set_watchdog(self, value: bool) -> None:
"""Signal the state of the watchdog."""
if not value:
self.redis.delete(self.watchdog_key)
return
self.redis.set(self.watchdog_key, 0, ex=self.WATCHDOG_TTL)
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
return bool(self.redis.exists(self.watchdog_key))
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
return bool(self.redis.exists(self.active_key))
def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.connector_active_key, 0, ex=self.CONNECTOR_ACTIVE_TTL)
def connector_active(self) -> bool:
if self.redis.exists(self.connector_active_key):
return True
return False
def connector_active_ttl(self) -> int:
"""Refer to https://redis.io/docs/latest/commands/ttl/
-2 means the key does not exist
-1 means the key exists but has no associated expire
Otherwise, returns the actual TTL of the key
"""
ttl = cast(int, self.redis.ttl(self.connector_active_key))
return ttl
def generator_locked(self) -> bool:
return bool(self.redis.exists(self.generator_lock_key))
def lock_key_by_batch(self, batch_n: int) -> str:
return f"{self.per_worker_lock_key}/{batch_n}"
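Reading the new prefixes and key formats in this file together, the per-batch worker lock keys are composed as in this sketch (key layout only, with illustrative ids):

PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
cc_pair_id, search_settings_id, batch_n = 7, 3, 12

per_worker_lock_key = f"{PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
batch_lock_key = f"{per_worker_lock_key}/{batch_n}"
assert batch_lock_key == "da_lock:indexing:per_worker_7/3/12"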
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
@@ -186,21 +90,9 @@ class RedisConnectorIndex:
self.redis.set(self.generator_complete_key, payload)
def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
def get_progress(self) -> int | None:
"""Returns None if the key doesn't exist. The"""
# TODO: move into fence?
bytes = self.redis.get(self.generator_progress_key)
if bytes is None:
return None
progress = int(cast(int, bytes))
return progress
def get_completion(self) -> int | None:
# TODO: move into fence?
bytes = self.redis.get(self.generator_complete_key)
if bytes is None:
return None
@@ -209,24 +101,22 @@ class RedisConnectorIndex:
return status
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.connector_active_key)
self.redis.delete(self.active_key)
self.redis.delete(self.filestore_lock_key)
self.redis.delete(self.db_lock_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.fence_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
# leaving these temporarily for backwards compat, TODO: remove
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):

View File

@@ -92,10 +92,7 @@ class RedisConnectorPrune:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorPrunePayload | None:
@@ -130,10 +127,7 @@ class RedisConnectorPrune:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
@property
def generator_complete(self) -> int | None:

View File

@@ -23,10 +23,7 @@ class RedisConnectorStop:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
def set_fence(self, value: bool) -> None:
if not value:
@@ -37,10 +34,7 @@ class RedisConnectorStop:
@property
def timed_out(self) -> bool:
if self.redis.exists(self.timeout_key):
return False
return True
return not bool(self.redis.exists(self.timeout_key))
def set_timeout(self) -> None:
"""After calling this, call timed_out to determine if the timeout has been

View File

@@ -6,6 +6,7 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -34,15 +35,15 @@ from onyx.db.document import get_documents_for_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import count_index_attempts_for_connector
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.models import SearchSettings
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot
from onyx.redis.redis_pool import get_redis_client
@@ -139,11 +140,6 @@ def get_cc_pair_full_info(
only_finished=False,
)
search_settings = get_current_search_settings(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings.id)
return CCPairFullInfo.from_models(
cc_pair_model=cc_pair,
number_of_index_attempts=count_index_attempts_for_connector(
@@ -159,7 +155,9 @@ def get_cc_pair_full_info(
),
num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user,
indexing=redis_connector_index.fenced,
indexing=bool(
latest_attempt and latest_attempt.status == IndexingStatus.IN_PROGRESS
),
)
@@ -195,31 +193,35 @@ def update_cc_pair_status(
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
redis_connector.stop.set_fence(True)
search_settings_list: list[SearchSettings] = get_active_search_settings_list(
db_session
# Request cancellation for any active indexing attempts for this cc_pair
active_attempts = (
db_session.execute(
select(IndexAttempt).where(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
)
.scalars()
.all()
)
while True:
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if not redis_connector_index.fenced:
continue
index_payload = redis_connector_index.payload
if not index_payload:
continue
if not index_payload.celery_task_id:
continue
for attempt in active_attempts:
try:
IndexingCoordination.request_cancellation(db_session, attempt.id)
# Revoke the task to prevent it from running
client_app.control.revoke(index_payload.celery_task_id)
if attempt.celery_task_id:
client_app.control.revoke(attempt.celery_task_id)
logger.info(
f"Requested cancellation for active indexing attempt {attempt.id} "
f"due to connector pause: cc_pair={cc_pair_id}"
)
except Exception:
logger.exception(
f"Failed to request cancellation for indexing attempt {attempt.id}"
)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
break
else:
redis_connector.stop.set_fence(False)

View File

@@ -106,7 +106,6 @@ from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
@@ -421,7 +420,7 @@ def extract_zip_metadata(zf: zipfile.ZipFile) -> dict[str, Any]:
return zip_metadata
def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResponse:
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -434,7 +433,7 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
deduped_file_paths = []
zip_metadata = {}
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
seen_zip = False
for file in files:
if file.content_type and file.content_type.startswith("application/zip"):
@@ -491,9 +490,8 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
return upload_files(files, db_session)
return upload_files(files)
@router.get("/admin/connector")
@@ -775,7 +773,7 @@ def get_connector_indexing_status(
if secondary_index
else get_current_search_settings
)
search_settings = get_search_settings(db_session)
get_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@@ -787,16 +785,13 @@ def get_connector_indexing_status(
# This may happen if background deletion is happening
continue
in_progress = False
if search_settings:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
in_progress = True
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
(connector.id, credential.id)
)
in_progress = bool(
latest_index_attempt
and latest_index_attempt.status == IndexingStatus.IN_PROGRESS
)
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(connector.id, credential.id)

View File

@@ -195,10 +195,9 @@ def undelete_persona(
@admin_router.post("/upload-image")
def upload_file(
file: UploadFile,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> dict[str, str]:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_type = ChatFileType.IMAGE
file_id = file_store.save_file(
content=file.file,

View File

@@ -205,6 +205,6 @@ def create_deletion_attempt_for_connector_id(
if cc_pair.connector.source == DocumentSource.FILE:
connector = cc_pair.connector
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_name in connector.connector_specific_config.get("file_locations", []):
file_store.delete_file(file_name)

View File

@@ -21,7 +21,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@@ -113,16 +113,13 @@ def upsert_ingestion_doc(
information_content_classification_model = InformationContentClassificationModel()
indexing_pipeline = build_indexing_pipeline(
indexing_pipeline_result = run_indexing_pipeline(
embedder=index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=curr_doc_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
)
indexing_pipeline_result = indexing_pipeline(
document_batch=[document],
index_attempt_metadata=IndexAttemptMetadata(
connector_id=cc_pair.connector_id,
@@ -148,16 +145,13 @@ def upsert_ingestion_doc(
active_search_settings.secondary, None
)
sec_ind_pipeline = build_indexing_pipeline(
run_indexing_pipeline(
embedder=new_index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=sec_doc_index,
ignore_time_skip=True,
db_session=db_session,
tenant_id=tenant_id,
)
sec_ind_pipeline(
document_batch=[document],
index_attempt_metadata=IndexAttemptMetadata(
connector_id=cc_pair.connector_id,

View File

@@ -720,7 +720,7 @@ def upload_files_for_chat(
detail="File size must be less than 20MB",
)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
@@ -823,10 +823,9 @@ def upload_files_for_chat(
@router.get("/file/{file_id:path}")
def fetch_chat_file(
file_id: str,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> Response:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id)
if not file_record:
raise HTTPException(status_code=404, detail="File not found")

View File

@@ -11,7 +11,6 @@ from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAU
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.file_store.file_store import get_default_file_store
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
@@ -40,9 +39,8 @@ class OnyxRuntime:
onyx_file: FileWithMimeType | None = None
if db_filename:
with get_session_with_shared_schema() as db_session:
file_store = get_default_file_store(db_session)
onyx_file = file_store.get_file_with_mime_type(db_filename)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(db_filename)
if not onyx_file:
onyx_file = OnyxStaticFileManager.get_static(static_filename)

View File

@@ -17,7 +17,6 @@ from requests import JSONDecodeError
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
@@ -205,27 +204,26 @@ class CustomTool(BaseTool):
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_id = str(uuid.uuid4())
file_id = str(uuid.uuid4())
# Handle both binary and text content
if isinstance(file_content, str):
content = BytesIO(file_content.encode())
else:
content = BytesIO(file_content)
# Handle both binary and text content
if isinstance(file_content, str):
content = BytesIO(file_content.encode())
else:
content = BytesIO(file_content)
file_store.save_file(
file_id=file_id,
content=content,
display_name=file_id,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=content_type,
file_metadata={
"content_type": content_type,
},
)
file_store.save_file(
file_id=file_id,
content=content,
display_name=file_id,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=content_type,
file_metadata={
"content_type": content_type,
},
)
return [file_id]
@@ -328,22 +326,21 @@ class CustomTool(BaseTool):
# Load files from storage
files = []
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_id in response.tool_result.file_ids:
try:
file_io = file_store.read_file(file_id, mode="b")
files.append(
InMemoryChatFile(
file_id=file_id,
filename=file_id,
content=file_io.read(),
file_type=file_type,
)
for file_id in response.tool_result.file_ids:
try:
file_io = file_store.read_file(file_id, mode="b")
files.append(
InMemoryChatFile(
file_id=file_id,
filename=file_id,
content=file_io.read(),
file_type=file_type,
)
except Exception:
logger.exception(f"Failed to read file {file_id}")
)
except Exception:
logger.exception(f"Failed to read file {file_id}")
# Update prompt with file content
prompt_builder.update_user_prompt(

View File

@@ -207,6 +207,7 @@ def setup_logger(
name: str = __name__,
log_level: int = get_log_level_from_str(),
extra: MutableMapping[str, Any] | None = None,
propagate: bool = True,
) -> OnyxLoggingAdapter:
logger = logging.getLogger(name)
@@ -244,6 +245,12 @@ def setup_logger(
logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs) # type: ignore
# After handler configuration, disable propagation to avoid duplicate logs
# Prevent messages from propagating to the root logger which can cause
# duplicate log entries when the root logger is also configured with its
# own handler (e.g. by Uvicorn / Celery).
logger.propagate = propagate
return OnyxLoggingAdapter(logger, extra=extra)
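With propagation now controlled by a parameter (defaulting to True), callers that need to avoid duplicate lines when the root logger is also configured can opt out per logger. A hedged usage sketch; the logger name is illustrative:

# Disable propagation for this logger only, e.g. inside a Celery worker process
# where the root logger already has its own handler.
task_logger = setup_logger("onyx.background.docprocessing", propagate=False)
task_logger.info("docprocessing worker ready")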

View File

@@ -384,3 +384,24 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
)
next_ind += 1
del future_to_index[future]
def parallel_yield_from_funcs(
funcs: list[Callable[..., R]],
max_workers: int = 10,
) -> Iterator[R]:
"""
Runs the list of functions with thread-level parallelism, yielding
results as available. The asynchronous nature of this yielding means
that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
FURTHER ITEMS WERE PRODUCED by the input funcs. Only use this function
if you are consuming all elements from the functions OR it is acceptable
for some extra function code to run and not have the result(s) yielded.
"""
def func_wrapper(func: Callable[[], R]) -> Iterator[R]:
yield func()
yield from parallel_yield(
[func_wrapper(func) for func in funcs], max_workers=max_workers
)
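A hedged usage sketch for the new helper; fetch_page and the URLs are hypothetical, and parallel_yield_from_funcs is assumed to be imported from the same utils module this hunk extends (path not shown in the diff):

import time
from functools import partial


def fetch_page(url: str) -> str:
    time.sleep(0.1)  # stand-in for I/O-bound work
    return f"contents of {url}"


urls = ["https://example.com/a", "https://example.com/b", "https://example.com/c"]

# Results are yielded as each call finishes; completion order is not guaranteed,
# and consuming the iterator fully is the intended usage per the docstring above.
for page in parallel_yield_from_funcs([partial(fetch_page, u) for u in urls], max_workers=3):
    print(page)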

View File

@@ -59,23 +59,23 @@ def run_jobs() -> None:
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation",
]
cmd_worker_indexing = [
cmd_worker_docprocessing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.indexing",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=1",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"--queues=connector_indexing",
"--hostname=docprocessing@%n",
"--queues=docprocessing",
]
cmd_worker_user_files_indexing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.indexing",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
@@ -111,6 +111,19 @@ def run_jobs() -> None:
"--queues=kg_processing",
]
cmd_worker_docfetching = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"--queues=connector_doc_fetching,user_files_indexing",
]
cmd_beat = [
"celery",
"-A",
@@ -132,8 +145,11 @@ def run_jobs() -> None:
cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_indexing_process = subprocess.Popen(
cmd_worker_indexing, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
worker_docprocessing_process = subprocess.Popen(
cmd_worker_docprocessing,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_user_files_indexing_process = subprocess.Popen(
@@ -157,6 +173,13 @@ def run_jobs() -> None:
text=True,
)
worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
@@ -171,8 +194,8 @@ def run_jobs() -> None:
worker_heavy_thread = threading.Thread(
target=monitor_process, args=("HEAVY", worker_heavy_process)
)
worker_indexing_thread = threading.Thread(
target=monitor_process, args=("INDEX", worker_indexing_process)
worker_docprocessing_thread = threading.Thread(
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
)
worker_user_files_indexing_thread = threading.Thread(
target=monitor_process,
@@ -184,24 +207,29 @@ def run_jobs() -> None:
worker_kg_processing_thread = threading.Thread(
target=monitor_process, args=("KG_PROCESSING", worker_kg_processing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_primary_thread.start()
worker_light_thread.start()
worker_heavy_thread.start()
worker_indexing_thread.start()
worker_docprocessing_thread.start()
worker_user_files_indexing_thread.start()
worker_monitoring_thread.start()
worker_kg_processing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()
worker_primary_thread.join()
worker_light_thread.join()
worker_heavy_thread.join()
worker_indexing_thread.join()
worker_docprocessing_thread.join()
worker_user_files_indexing_thread.join()
worker_monitoring_thread.join()
worker_kg_processing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()

View File

@@ -213,7 +213,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
if file_names:
logger.notice("Deleting stored files!")
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_name in file_names:
logger.notice(f"Deleting file {file_name}")
file_store.delete_file(file_name)

View File

@@ -65,12 +65,12 @@ autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_indexing]
command=celery -A onyx.background.celery.versioned_apps.indexing worker
[program:celery_worker_docprocessing]
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
--loglevel=INFO
--hostname=indexing@%%n
-Q connector_indexing
stdout_logfile=/var/log/celery_worker_indexing.log
--hostname=docprocessing@%%n
-Q docprocessing
stdout_logfile=/var/log/celery_worker_docprocessing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
@@ -78,7 +78,7 @@ startsecs=10
stopasgroup=true
[program:celery_worker_user_files_indexing]
command=celery -A onyx.background.celery.versioned_apps.indexing worker
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
--loglevel=INFO
--hostname=user_files_indexing@%%n
-Q user_files_indexing
@@ -89,6 +89,18 @@ autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
--loglevel=INFO
--hostname=docfetching@%%n
-Q connector_doc_fetching
stdout_logfile=/var/log/celery_worker_docfetching.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
--loglevel=INFO
@@ -161,6 +173,7 @@ command=tail -qF
/var/log/celery_worker_heavy.log
/var/log/celery_worker_indexing.log
/var/log/celery_worker_user_files_indexing.log
/var/log/celery_worker_docfetching.log
/var/log/celery_worker_monitoring.log
/var/log/slack_bot.log
/var/log/supervisord_watchdog_celery_beat.log

View File

@@ -30,14 +30,12 @@ def mock_filestore_record() -> MagicMock:
@patch("onyx.connectors.file.connector.get_default_file_store")
@patch("onyx.connectors.file.connector.get_session_with_current_tenant")
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None
)
def test_single_text_file_with_metadata(
mock_get_unstructured_api_key: MagicMock,
mock_get_session: MagicMock,
mock_get_filestore: MagicMock,
mock_db_session: MagicMock,
mock_file_store: MagicMock,
mock_filestore_record: MagicMock,
@@ -48,13 +46,18 @@ def test_single_text_file_with_metadata(
"doc_updated_at": "2001-01-01T00:00:00Z"}\n'
b"Test answer is 12345"
)
mock_get_filestore = MagicMock()
mock_get_filestore.return_value = mock_file_store
mock_file_store.read_file_record.return_value = mock_filestore_record
mock_get_session.return_value.__enter__.return_value = mock_db_session
mock_file_store.read_file.return_value = file_content
connector = LocalFileConnector(file_locations=["test.txt"], zip_metadata={})
batches = list(connector.load_from_state())
with patch(
"onyx.connectors.file.connector.get_default_file_store",
return_value=mock_file_store,
):
connector = LocalFileConnector(file_locations=["test.txt"], zip_metadata={})
batches = list(connector.load_from_state())
assert len(batches) == 1
docs = batches[0]
@@ -69,26 +72,22 @@ def test_single_text_file_with_metadata(
assert doc.doc_updated_at == datetime(2001, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
@patch("onyx.connectors.file.connector.get_default_file_store")
@patch("onyx.connectors.file.connector.get_session_with_current_tenant")
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None
)
def test_two_text_files_with_zip_metadata(
mock_get_unstructured_api_key: MagicMock,
mock_get_session: MagicMock,
mock_get_filestore: MagicMock,
mock_db_session: MagicMock,
mock_file_store: MagicMock,
) -> None:
file1_content = io.BytesIO(b"File 1 content")
file2_content = io.BytesIO(b"File 2 content")
mock_get_filestore = MagicMock()
mock_get_filestore.return_value = mock_file_store
mock_file_store.read_file_record.side_effect = [
MagicMock(file_id=str(uuid4()), display_name="file1.txt"),
MagicMock(file_id=str(uuid4()), display_name="file2.txt"),
]
mock_get_session.return_value.__enter__.return_value = mock_db_session
mock_file_store.read_file.side_effect = [file1_content, file2_content]
zip_metadata = {
"file1.txt": {
@@ -109,10 +108,14 @@ def test_two_text_files_with_zip_metadata(
},
}
connector = LocalFileConnector(
file_locations=["file1.txt", "file2.txt"], zip_metadata=zip_metadata
)
batches = list(connector.load_from_state())
with patch(
"onyx.connectors.file.connector.get_default_file_store",
return_value=mock_file_store,
):
connector = LocalFileConnector(
file_locations=["file1.txt", "file2.txt"], zip_metadata=zip_metadata
)
batches = list(connector.load_from_state())
assert len(batches) == 1
docs = batches[0]

View File

@@ -70,20 +70,19 @@ def _get_all_backend_configs() -> List[BackendConfig]:
if S3_ENDPOINT_URL:
minio_access_key = "minioadmin"
minio_secret_key = "minioadmin"
if minio_access_key and minio_secret_key:
configs.append(
{
"endpoint_url": S3_ENDPOINT_URL,
"access_key": minio_access_key,
"secret_key": minio_secret_key,
"region": "us-east-1",
"verify_ssl": False,
"backend_name": "MinIO",
}
)
configs.append(
{
"endpoint_url": S3_ENDPOINT_URL,
"access_key": minio_access_key,
"secret_key": minio_secret_key,
"region": "us-east-1",
"verify_ssl": False,
"backend_name": "MinIO",
}
)
# AWS S3 configuration (if credentials are available)
if S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
elif S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
configs.append(
{
"endpoint_url": None,
@@ -116,7 +115,6 @@ def file_store(
# Create S3BackedFileStore with backend-specific configuration
store = S3BackedFileStore(
db_session=db_session,
bucket_name=TEST_BUCKET_NAME,
aws_access_key_id=backend_config["access_key"],
aws_secret_access_key=backend_config["secret_key"],
@@ -827,7 +825,6 @@ class TestS3BackedFileStore:
# Create a new database session for each worker to avoid conflicts
with get_session_with_current_tenant() as worker_session:
worker_file_store = S3BackedFileStore(
db_session=worker_session,
bucket_name=current_bucket_name,
aws_access_key_id=current_access_key,
aws_secret_access_key=current_secret_key,
@@ -849,6 +846,7 @@ class TestS3BackedFileStore:
display_name=f"Worker {worker_id} File",
file_origin=file_origin,
file_type=file_type,
db_session=worker_session,
)
results.append((file_name, content))
return True
@@ -885,3 +883,94 @@ class TestS3BackedFileStore:
read_content_io = file_store.read_file(file_id)
actual_content: str = read_content_io.read().decode("utf-8")
assert actual_content == expected_content
def test_list_files_by_prefix(self, file_store: S3BackedFileStore) -> None:
"""Test listing files by prefix returns only correctly prefixed files"""
test_prefix = "documents-batch-"
# Files that should be returned (start with the prefix)
prefixed_files: List[str] = [
f"{test_prefix}001.txt",
f"{test_prefix}002.json",
f"{test_prefix}abc.pdf",
f"{test_prefix}xyz-final.docx",
]
# Files that should NOT be returned (don't start with prefix, even if they contain it)
non_prefixed_files: List[str] = [
f"other-{test_prefix}001.txt", # Contains prefix but doesn't start with it
f"backup-{test_prefix}data.txt", # Contains prefix but doesn't start with it
f"{uuid.uuid4()}.txt", # Random file without prefix
"reports-001.pdf", # Different prefix
f"my-{test_prefix[:-1]}.txt", # Similar but not exact prefix
]
all_files = prefixed_files + non_prefixed_files
saved_file_ids: List[str] = []
# Save all test files
for file_name in all_files:
content = f"Content for {file_name}"
content_io = BytesIO(content.encode("utf-8"))
returned_file_id = file_store.save_file(
content=content_io,
display_name=f"Display: {file_name}",
file_origin=FileOrigin.OTHER,
file_type="text/plain",
file_id=file_name,
)
saved_file_ids.append(returned_file_id)
# Verify file was saved
assert returned_file_id == file_name
# Test the list_files_by_prefix functionality
prefix_results = file_store.list_files_by_prefix(test_prefix)
# Extract file IDs from results
returned_file_ids = [record.file_id for record in prefix_results]
# Verify correct number of files returned
assert len(returned_file_ids) == len(prefixed_files), (
f"Expected {len(prefixed_files)} files with prefix '{test_prefix}', "
f"but got {len(returned_file_ids)}: {returned_file_ids}"
)
# Verify all prefixed files are returned
for expected_file_id in prefixed_files:
assert expected_file_id in returned_file_ids, (
f"File '{expected_file_id}' should be in results but was not found. "
f"Returned files: {returned_file_ids}"
)
# Verify no non-prefixed files are returned
for unexpected_file_id in non_prefixed_files:
assert unexpected_file_id not in returned_file_ids, (
f"File '{unexpected_file_id}' should NOT be in results but was found. "
f"Returned files: {returned_file_ids}"
)
# Verify the returned records have correct properties
for record in prefix_results:
assert record.file_id.startswith(test_prefix)
assert record.display_name == f"Display: {record.file_id}"
assert record.file_origin == FileOrigin.OTHER
assert record.file_type == "text/plain"
assert record.bucket_name == file_store._get_bucket_name()
# Test with empty prefix (should return all files we created)
all_results = file_store.list_files_by_prefix("")
all_returned_ids = [record.file_id for record in all_results]
# Should include all our test files
for file_id in saved_file_ids:
assert (
file_id in all_returned_ids
), f"File '{file_id}' should be in results for empty prefix"
# Test with non-existent prefix
nonexistent_results = file_store.list_files_by_prefix("nonexistent-prefix-")
assert (
len(nonexistent_results) == 0
), "Should return empty list for non-existent prefix"

View File

@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "127.0.0.1"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 60
MAX_DELAY = 300
GENERAL_HEADERS = {"Content-Type": "application/json"}

View File

@@ -203,7 +203,9 @@ class IndexAttemptManager:
)
if index_attempt.status and index_attempt.status.is_terminal():
print(f"IndexAttempt {index_attempt_id} completed")
print(
f"IndexAttempt {index_attempt_id} completed with status {index_attempt.status}"
)
return
elapsed = time.monotonic() - start
@@ -216,6 +218,7 @@ class IndexAttemptManager:
f"Waiting for IndexAttempt {index_attempt_id} to complete. "
f"elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)
@staticmethod
def get_index_attempt_errors_for_cc_pair(

View File

@@ -22,6 +22,7 @@ from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
@@ -398,6 +399,13 @@ def reset_vespa_multitenant() -> None:
time.sleep(5)
def reset_file_store() -> None:
"""Reset the FileStore."""
filestore = get_default_file_store()
for file_record in filestore.list_files_by_prefix(""):
filestore.delete_file(file_record.file_id)
def reset_all() -> None:
if os.environ.get("SKIP_RESET", "").lower() == "true":
logger.info("Skipping reset.")
@@ -407,6 +415,8 @@ def reset_all() -> None:
reset_postgres()
logger.info("Resetting Vespa...")
reset_vespa()
logger.info("Resetting FileStore...")
reset_file_store()
def reset_all_multitenant() -> None:

View File

@@ -60,7 +60,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
)
CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
cc_pair_1, now, timeout=300, user_performing_action=admin_user
)
now = datetime.now(timezone.utc)
@@ -74,7 +74,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
)
CCPairManager.wait_for_indexing_completion(
cc_pair_2, now, timeout=120, user_performing_action=admin_user
cc_pair_2, now, timeout=300, user_performing_action=admin_user
)
info_1 = CCPairManager.get_single(cc_pair_1.id, user_performing_action=admin_user)

View File

@@ -90,7 +90,7 @@ def test_image_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=datetime.now(timezone.utc),
timeout=180,
timeout=300,
user_performing_action=admin_user,
)

View File

@@ -343,8 +343,6 @@ def test_mock_connector_checkpoint_recovery(
"""Test that checkpointing works correctly when an unhandled exception occurs
and that subsequent runs pick up from the last successful checkpoint."""
# Create test documents
# Create 100 docs for first batch, this is needed to get past the
# `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`.
docs_batch_1 = [create_test_document() for _ in range(100)]
doc2 = create_test_document()
doc3 = create_test_document()
@@ -421,10 +419,13 @@ def test_mock_connector_checkpoint_recovery(
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 101 # 100 docs from first batch + doc2
document_ids = {doc.id for doc in documents}
assert doc2.id in document_ids
assert all(doc.id in document_ids for doc in docs_batch_1)
# This is no longer guaranteed because docfetching and docprocessing are decoupled!
# Some batches may not be processed when docfetching fails, but they should still stick around
# in the filestore and be ready for the next run.
# assert len(documents) == 101 # 100 docs from first batch + doc2
# document_ids = {doc.id for doc in documents}
# assert doc2.id in document_ids
# assert all(doc.id in document_ids for doc in docs_batch_1)
# Get the checkpoints that were sent to the mock server
response = mock_server_client.get("/get-checkpoints")

View File

@@ -3,7 +3,7 @@ import uuid
import httpx
from onyx.background.celery.tasks.indexing.utils import (
from onyx.background.celery.tasks.docprocessing.utils import (
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE,
)
from onyx.configs.constants import DocumentSource

View File

@@ -136,7 +136,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
)
CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
cc_pair_1, now, timeout=300, user_performing_action=admin_user
)
selected_cc_pair = CCPairManager.get_indexing_status_by_id(

View File

@@ -77,18 +77,15 @@ def sample_file_io(sample_content: bytes) -> BytesIO:
class TestExternalStorageFileStore:
"""Test external storage file store functionality (S3-compatible)"""
def test_get_default_file_store_s3(self, db_session: Session) -> None:
def test_get_default_file_store_s3(self) -> None:
"""Test that external storage file store is returned"""
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
assert isinstance(file_store, S3BackedFileStore)
def test_s3_client_initialization_with_credentials(
self, db_session: Session
) -> None:
def test_s3_client_initialization_with_credentials(self) -> None:
"""Test S3 client initialization with explicit credentials"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="test-key",
aws_secret_access_key="test-secret",
@@ -110,7 +107,6 @@ class TestExternalStorageFileStore:
"""Test S3 client initialization with IAM role (no explicit credentials)"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id=None,
aws_secret_access_key=None,
@@ -129,16 +125,16 @@ class TestExternalStorageFileStore:
assert "aws_access_key_id" not in call_kwargs
assert "aws_secret_access_key" not in call_kwargs
def test_s3_bucket_name_configuration(self, db_session: Session) -> None:
def test_s3_bucket_name_configuration(self) -> None:
"""Test S3 bucket name configuration"""
with patch(
"onyx.file_store.file_store.S3_FILE_STORE_BUCKET_NAME", "my-test-bucket"
):
file_store = S3BackedFileStore(db_session, bucket_name="my-test-bucket")
file_store = S3BackedFileStore(bucket_name="my-test-bucket")
bucket_name: str = file_store._get_bucket_name()
assert bucket_name == "my-test-bucket"
def test_s3_key_generation_default_prefix(self, db_session: Session) -> None:
def test_s3_key_generation_default_prefix(self) -> None:
"""Test S3 key generation with default prefix"""
with (
patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"),
@@ -147,11 +143,11 @@ class TestExternalStorageFileStore:
return_value="test-tenant",
),
):
file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
file_store = S3BackedFileStore(bucket_name="test-bucket")
s3_key: str = file_store._get_s3_key("test-file.txt")
assert s3_key == "onyx-files/test-tenant/test-file.txt"
def test_s3_key_generation_custom_prefix(self, db_session: Session) -> None:
def test_s3_key_generation_custom_prefix(self) -> None:
"""Test S3 key generation with custom prefix"""
with (
patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "custom-prefix"),
@@ -161,17 +157,15 @@ class TestExternalStorageFileStore:
),
):
file_store = S3BackedFileStore(
db_session, bucket_name="test-bucket", s3_prefix="custom-prefix"
bucket_name="test-bucket", s3_prefix="custom-prefix"
)
s3_key: str = file_store._get_s3_key("test-file.txt")
assert s3_key == "custom-prefix/test-tenant/test-file.txt"
def test_s3_key_generation_with_different_tenant_ids(
self, db_session: Session
) -> None:
def test_s3_key_generation_with_different_tenant_ids(self) -> None:
"""Test S3 key generation with different tenant IDs"""
with patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"):
file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
file_store = S3BackedFileStore(bucket_name="test-bucket")
# Test with tenant ID "tenant-1"
with patch(
@@ -224,9 +218,7 @@ class TestExternalStorageFileStore:
with patch("onyx.db.file_record.upsert_filerecord") as mock_upsert:
mock_upsert.return_value = Mock()
file_store = S3BackedFileStore(
mock_db_session, bucket_name="test-bucket"
)
file_store = S3BackedFileStore(bucket_name="test-bucket")
# This should not raise an exception
file_store.save_file(
@@ -235,6 +227,7 @@ class TestExternalStorageFileStore:
display_name="Test File",
file_origin=FileOrigin.OTHER,
file_type="text/plain",
db_session=mock_db_session,
)
# Verify S3 client was called correctly
@@ -244,14 +237,13 @@ class TestExternalStorageFileStore:
assert call_args[1]["Key"] == "onyx-files/public/test-file.txt"
assert call_args[1]["ContentType"] == "text/plain"
def test_minio_client_initialization(self, db_session: Session) -> None:
def test_minio_client_initialization(self) -> None:
"""Test S3 client initialization with MinIO endpoint"""
with (
patch("boto3.client") as mock_boto3,
patch("urllib3.disable_warnings"),
):
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="minioadmin",
aws_secret_access_key="minioadmin",
@@ -277,11 +269,10 @@ class TestExternalStorageFileStore:
assert config.signature_version == "s3v4"
assert config.s3["addressing_style"] == "path"
def test_minio_ssl_verification_enabled(self, db_session: Session) -> None:
def test_minio_ssl_verification_enabled(self) -> None:
"""Test MinIO with SSL verification enabled"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="test-key",
aws_secret_access_key="test-secret",
@@ -295,11 +286,10 @@ class TestExternalStorageFileStore:
assert "verify" not in call_kwargs or call_kwargs.get("verify") is not False
assert call_kwargs["endpoint_url"] == "https://minio.example.com"
def test_aws_s3_without_endpoint_url(self, db_session: Session) -> None:
def test_aws_s3_without_endpoint_url(self) -> None:
"""Test that regular AWS S3 doesn't include endpoint URL or custom config"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="test-key",
aws_secret_access_key="test-secret",
@@ -321,8 +311,8 @@ class TestExternalStorageFileStore:
class TestFileStoreInterface:
"""Test the general file store interface"""
def test_file_store_always_external_storage(self, db_session: Session) -> None:
def test_file_store_always_external_storage(self) -> None:
"""Test that external storage file store is always returned"""
# File store should always be S3BackedFileStore regardless of environment
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
assert isinstance(file_store, S3BackedFileStore)

Some files were not shown because too many files have changed in this diff