forked from github/onyx
feat: connector indexing decoupling (#4893)
* WIP
* renamed and moved tasks (WIP)
* minio migration
* bug fixes and finally add document batch storage
* WIP: can succeed but status is error
* WIP
* import fixes
* working v1 of decoupled
* catastrophe handling
* refactor
* remove unused db session in prep for new approach
* renaming and docstrings (untested)
* renames
* WIP with no more indexing fences
* robustness improvements
* clean up rebase
* migration and salesforce rate limits
* minor tweaks
* test fix
* connector pausing behavior
* correct checkpoint resumption logic
* cleanups in docfetching
* add heartbeat file
* update template jsonc
* deployment fixes
* fix vespa httpx pool
* error handling
* cosmetic fixes
* dumb
* logging improvements and non-checkpointed connector fixes
* didn't save
* misc fixes
* fix import
* fix deletion of old files
* add in attempt prefix
* fix attempt prefix
* tiny log improvement
* minor changes
* fixed resumption behavior
* passing int tests
* fix unit test
* fixed unit tests
* trying timeout bump to see if int tests pass
* trying timeout bump to see if int tests pass
* fix autodiscovery
* helm chart fixes
* helm and logging
99  .vscode/launch.template.jsonc (vendored)
@@ -46,7 +46,8 @@
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery indexing",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery user files indexing",
|
||||
"Celery beat",
|
||||
"Celery monitoring"
|
||||
@@ -226,35 +227,66 @@
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery indexing",
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing"
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery indexing Console"
|
||||
},
|
||||
"consoleTitle": "Celery docfetching Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docprocessing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
"-Q",
|
||||
"docprocessing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
@@ -303,35 +335,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user files indexing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_files_indexing@%n",
|
||||
"-Q",
|
||||
"user_files_indexing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user files indexing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
|
||||
@@ -96,7 +96,7 @@ def get_google_drive_documents_from_database() -> list[dict]:
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT d.id, cc.id as cc_pair_id
|
||||
SELECT d.id
|
||||
FROM document d
|
||||
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
|
||||
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
|
||||
@@ -109,7 +109,7 @@ def get_google_drive_documents_from_database() -> list[dict]:
|
||||
|
||||
documents = []
|
||||
for row in result:
|
||||
documents.append({"document_id": row.id, "cc_pair_id": row.cc_pair_id})
|
||||
documents.append({"document_id": row.id})
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
@@ -0,0 +1,115 @@
"""add_indexing_coordination

Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add database-based coordination fields (replacing Redis fencing)
    op.add_column(
        "index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "cancellation_requested",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )

    # Add batch coordination fields (replacing FileStore state)
    op.add_column(
        "index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "completed_batches", sa.Integer(), nullable=False, server_default="0"
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "total_failures_batch_level",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
    )

    # Progress tracking for stall detection
    op.add_column(
        "index_attempt",
        sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "last_batches_completed_count",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
    )

    # Heartbeat tracking for worker liveness detection
    op.add_column(
        "index_attempt",
        sa.Column(
            "heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
        ),
    )
    op.add_column(
        "index_attempt",
        sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
    )

    # Add index for coordination queries
    op.create_index(
        "ix_index_attempt_active_coordination",
        "index_attempt",
        ["connector_credential_pair_id", "search_settings_id", "status"],
    )


def downgrade() -> None:
    # Remove the new index
    op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")

    # Remove the new columns
    op.drop_column("index_attempt", "last_batches_completed_count")
    op.drop_column("index_attempt", "last_progress_time")
    op.drop_column("index_attempt", "last_heartbeat_time")
    op.drop_column("index_attempt", "last_heartbeat_value")
    op.drop_column("index_attempt", "heartbeat_counter")
    op.drop_column("index_attempt", "total_chunks")
    op.drop_column("index_attempt", "total_failures_batch_level")
    op.drop_column("index_attempt", "completed_batches")
    op.drop_column("index_attempt", "total_batches")
    op.drop_column("index_attempt", "cancellation_requested")
    op.drop_column("index_attempt", "celery_task_id")
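Taken together, these columns move indexing coordination out of Redis fences and into Postgres: celery_task_id and cancellation_requested cover task ownership and cancellation, total_batches / completed_batches track the docfetching-to-docprocessing hand-off, and the progress/heartbeat fields support stall detection. The stall and heartbeat checks themselves live in the monitoring and docprocessing tasks (largely not shown in this diff); the following is only a rough sketch of a progress-based check over the new columns, assuming the IndexAttempt ORM model mirrors them, with the threshold as an assumption.

# Illustrative only: a progress-based stall check over the new columns.
# Assumes the IndexAttempt ORM model mirrors the migration above; the real
# check lives in the monitoring tasks, and the threshold here is an assumption.
from datetime import datetime, timedelta, timezone

from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.models import IndexAttempt

STALL_THRESHOLD = timedelta(hours=3)  # the docstrings later mention a 3-6 hour window


def find_stalled_attempt_ids(db_session: Session) -> list[int]:
    """Return ids of attempts that have a worker but have made no recent progress."""
    cutoff = datetime.now(timezone.utc) - STALL_THRESHOLD
    stmt = select(IndexAttempt.id).where(
        IndexAttempt.celery_task_id.isnot(None),
        IndexAttempt.last_progress_time.isnot(None),
        IndexAttempt.last_progress_time < cutoff,
    )
    return list(db_session.scalars(stmt))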
@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:
|
||||
|
||||
# only create external store if we have files to migrate. This line
|
||||
# makes it so we need to have S3/MinIO configured to run this migration.
|
||||
external_store = get_s3_file_store(db_session=session)
|
||||
external_store = get_s3_file_store()
|
||||
|
||||
for i, file_id in enumerate(files_to_migrate, 1):
|
||||
print(f"Migrating file {i}/{total_files}: {file_id}")
|
||||
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
# Get database session
|
||||
bind = op.get_bind()
|
||||
session = Session(bind=bind)
|
||||
external_store = get_s3_file_store(db_session=session)
|
||||
external_store = get_s3_file_store()
|
||||
|
||||
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
|
||||
result = session.execute(
|
||||
|
||||
@@ -91,7 +91,7 @@ def export_query_history_task(
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store(db_session).save_file(
|
||||
get_default_file_store().save_file(
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
|
||||
@@ -422,7 +422,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
+ f"_{redis_connector.cc_pair_id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
+ f"_{redis_connector.cc_pair_id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
|
||||
|
||||
|
||||
def get_usage_report_data(
|
||||
db_session: Session,
|
||||
report_display_name: str,
|
||||
) -> IO:
|
||||
"""
|
||||
@@ -128,7 +127,7 @@ def get_usage_report_data(
|
||||
Returns:
|
||||
The usage report data.
|
||||
"""
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
# usage report may be very large, so don't load it all into memory
|
||||
return file_store.read_file(
|
||||
file_id=report_display_name, mode="b", use_tempfile=True
|
||||
|
||||
@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
|
||||
def put_logo(
|
||||
file: UploadFile,
|
||||
is_logotype: bool = False,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
|
||||
upload_logo(file=file, is_logotype=is_logotype)
|
||||
|
||||
|
||||
def fetch_logo_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
|
||||
|
||||
def fetch_logotype_helper(db_session: Session) -> Response:
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
|
||||
if not onyx_file:
|
||||
raise ValueError("get_onyx_file returned None!")
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import IO
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
|
||||
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
|
||||
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
|
||||
return "application/octet-stream"
|
||||
|
||||
|
||||
def upload_logo(
|
||||
db_session: Session, file: UploadFile | str, is_logotype: bool = False
|
||||
) -> bool:
|
||||
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
|
||||
content: IO[Any]
|
||||
|
||||
if isinstance(file, str):
|
||||
@@ -129,7 +126,7 @@ def upload_logo(
|
||||
display_name = file.filename
|
||||
file_type = file.content_type or "image/jpeg"
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_store.save_file(
|
||||
content=content,
|
||||
display_name=display_name,
|
||||
|
||||
@@ -358,7 +358,7 @@ def get_query_history_export_status(
|
||||
# If task is None, then it's possible that the task has already finished processing.
|
||||
# Therefore, we should then check if the export file has already been stored inside of the file-store.
|
||||
# If that *also* doesn't exist, then we can return a 404.
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
report_name = construct_query_history_report_name(request_id)
|
||||
has_file = file_store.has_file(
|
||||
@@ -385,7 +385,7 @@ def download_query_history_csv(
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
|
||||
report_name = construct_query_history_report_name(request_id)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
has_file = file_store.has_file(
|
||||
file_id=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
|
||||
@@ -53,7 +53,7 @@ def read_usage_report(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
try:
|
||||
file = get_usage_report_data(db_session, report_name)
|
||||
file = get_usage_report_data(report_name)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@@ -112,7 +112,7 @@ def create_new_usage_report(
|
||||
period: tuple[datetime, datetime] | None,
|
||||
) -> UsageReportMetadata:
|
||||
report_id = str(uuid.uuid4())
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
messages_file_id = generate_chat_messages_report(
|
||||
db_session, file_store, report_id, period
|
||||
|
||||
@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
|
||||
store_ee_settings(final_enterprise_settings)
|
||||
|
||||
|
||||
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
|
||||
def _seed_logo(logo_path: str | None) -> None:
|
||||
if logo_path:
|
||||
logger.notice("Uploading logo")
|
||||
upload_logo(db_session=db_session, file=logo_path)
|
||||
upload_logo(file=logo_path)
|
||||
|
||||
|
||||
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
|
||||
@@ -245,7 +245,7 @@ def seed_db() -> None:
|
||||
if seed_config.custom_tools is not None:
|
||||
_seed_custom_tools(db_session, seed_config.custom_tools)
|
||||
|
||||
_seed_logo(db_session, seed_config.seeded_logo_path)
|
||||
_seed_logo(seed_config.seeded_logo_path)
|
||||
_seed_enterprise_settings(seed_config)
|
||||
_seed_analytics_script(seed_config)
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -93,7 +94,13 @@ def on_task_prerun(
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
pass
|
||||
# Reset any per-task logging context so that prefixes (e.g. pruning_ctx)
|
||||
# from a previous task executed in the same worker process do not leak
|
||||
# into the next task's log messages. This fixes incorrect [CC Pair:/Index Attempt]
|
||||
# prefixes observed when a pruning task finishes and an indexing task
|
||||
# runs in the same process.
|
||||
|
||||
LoggerContextVars.reset()
|
||||
|
||||
|
||||
def on_task_postrun(
|
||||
@@ -474,7 +481,8 @@ class TenantContextFilter(logging.Filter):
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id:
|
||||
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
|
||||
# Match the 8 character tenant abbreviation used in OnyxLoggingAdapter
|
||||
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8]
|
||||
record.name = f"[t:{tenant_id}]"
|
||||
else:
|
||||
record.name = ""
|
||||
|
||||
102  backend/onyx/background/celery/apps/docfetching.py (new file)
@@ -0,0 +1,102 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Fewer startup checks in the multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
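One note on naming: the launch configuration and deployment entries point the worker at onyx.background.celery.versioned_apps.docfetching, while the new app module above lives under apps/. The versioned wrapper is not included in this excerpt; presumably it just resolves and re-exports the appropriate (EE or base) app, along the lines of this hypothetical sketch.

# Hypothetical contents of onyx/background/celery/versioned_apps/docfetching.py,
# which is referenced by the launch config but not shown in this diff.
# Assumed to simply re-export the Celery app defined in apps/docfetching.py.
from onyx.background.celery.apps.docfetching import celery_app

app = celery_app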
@@ -12,7 +12,7 @@ from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -21,7 +21,7 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.indexing")
|
||||
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME)
|
||||
|
||||
# rkuo: Transient errors keep happening in the indexing watchdog threads.
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
@@ -108,6 +108,6 @@ for bootstep in base_bootsteps:
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
@@ -116,6 +116,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.result import AsyncResult
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
@@ -18,9 +19,6 @@ from redis.lock import Lock as RedisLock
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.tasks.indexing.utils import (
|
||||
get_unfenced_index_attempt_ids,
|
||||
)
|
||||
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
|
||||
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
@@ -30,6 +28,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
@@ -168,24 +167,50 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
# This uses database coordination instead of Redis fencing
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
# Get potentially orphaned attempts (those with active status and task IDs)
|
||||
potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids(
|
||||
db_session
|
||||
)
|
||||
|
||||
for attempt_id in potentially_orphaned_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
if not attempt:
|
||||
|
||||
# handle case where not started or docfetching is done but indexing is not
|
||||
if (
|
||||
not attempt
|
||||
or not attempt.celery_task_id
|
||||
or attempt.total_batches is not None
|
||||
):
|
||||
continue
|
||||
|
||||
failure_reason = (
|
||||
f"Canceling leftover index attempt found on startup: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
logger.exception(
|
||||
f"Marking attempt {attempt.id} as canceled due to validation error 2"
|
||||
)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
# Check if the Celery task actually exists
|
||||
try:
|
||||
|
||||
result: AsyncResult = AsyncResult(attempt.celery_task_id)
|
||||
|
||||
# If the task is not in PENDING state, it exists in Celery
|
||||
if result.state != "PENDING":
|
||||
continue
|
||||
|
||||
# Task is orphaned - mark as failed
|
||||
failure_reason = (
|
||||
f"Orphaned index attempt found on startup - Celery task not found: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id} "
|
||||
f"celery_task_id={attempt.celery_task_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
mark_attempt_canceled(attempt.id, db_session, failure_reason)
|
||||
|
||||
except Exception:
|
||||
# If we can't check the task status, be conservative and continue
|
||||
logger.warning(
|
||||
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
|
||||
f"task_id={attempt.celery_task_id}"
|
||||
)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
@@ -292,7 +317,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -26,7 +26,7 @@ def celery_get_unacked_length(r: Redis) -> int:
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
Unacked entries belonging to the indexing queues are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
|
||||
|
||||
22  backend/onyx/background/celery/configs/docfetching.py (new file)
@@ -0,0 +1,22 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# Docfetching worker configuration
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
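CELERY_WORKER_DOCFETCHING_CONCURRENCY is imported from onyx.configs.app_configs, which is not part of this excerpt. A plausible definition, mirroring the other worker concurrency settings, would read an environment variable; treat the variable name and default below as assumptions rather than what the PR actually adds.

# Hypothetical sketch of the config value consumed above (not shown in this diff).
# The environment variable name and the default of 1 are assumptions.
import os

CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY") or 1
)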
@@ -1,5 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -24,6 +24,6 @@ task_acks_late = shared_config.task_acks_late
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -40,9 +40,11 @@ from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -69,13 +71,21 @@ def revoke_tasks_blocking_deletion(
|
||||
) -> None:
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
try:
|
||||
index_payload = redis_connector_index.payload
|
||||
if index_payload and index_payload.celery_task_id:
|
||||
app.control.revoke(index_payload.celery_task_id)
|
||||
recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=redis_connector.cc_pair_id,
|
||||
search_settings_id=search_settings.id,
|
||||
limit=1,
|
||||
db_session=db_session,
|
||||
)
|
||||
if (
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
and recent_index_attempts[0].celery_task_id
|
||||
):
|
||||
app.control.revoke(recent_index_attempts[0].celery_task_id)
|
||||
task_logger.info(
|
||||
f"Revoked indexing task {index_payload.celery_task_id}."
|
||||
f"Revoked indexing task {recent_index_attempts[0].celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking indexing task")
|
||||
@@ -281,8 +291,16 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# do not proceed if connector indexing or connector pruning are running
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings.id,
|
||||
limit=1,
|
||||
db_session=db_session,
|
||||
)
|
||||
if (
|
||||
recent_index_attempts
|
||||
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
raise TaskDependencyError(
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
|
||||
675  backend/onyx/background/celery/tasks/docfetching/tasks.py (new file)
@@ -0,0 +1,675 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder
|
||||
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.models import DocProcessingContext
|
||||
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
|
||||
from onyx.background.celery.tasks.models import SimpleJobResult
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _verify_indexing_attempt(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the indexing attempt exists and is in the correct state.
|
||||
"""
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
)
|
||||
|
||||
if attempt.connector_credential_pair_id != cc_pair_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - CC pair mismatch: "
|
||||
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
if attempt.search_settings_id != search_settings_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Search settings mismatch: "
|
||||
f"expected={search_settings_id} actual={attempt.search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
if attempt.status not in [
|
||||
IndexingStatus.NOT_STARTED,
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
]:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Invalid attempt status: "
|
||||
f"attempt_id={index_attempt_id} status={attempt.status}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
# Check for cancellation
|
||||
if IndexingCoordination.check_cancellation_requested(
|
||||
db_session, index_attempt_id
|
||||
):
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"docfetching_task - IndexAttempt verified: "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
|
||||
def docfetching_task(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This function is run in a SimpleJob as a new process. It validates the attempt and
connector state (deletion/stop fences, attempt status, cancellation), acquires a
per-attempt lock, and then calls run_indexing_entrypoint to do the real work.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
|
||||
# Start heartbeat for this indexing attempt
|
||||
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
|
||||
try:
|
||||
_docfetching_task(
|
||||
app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id
|
||||
)
|
||||
finally:
|
||||
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
|
||||
|
||||
|
||||
def _docfetching_task(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# Since docfetching_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# TODO: remove all fences; have all such signals be set in postgres instead
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise SimpleJobException(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
)
|
||||
|
||||
# Verify the indexing attempt exists and is valid
|
||||
# This replaces the Redis fence payload waiting
|
||||
_verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id)
|
||||
|
||||
# We still need a basic Redis lock to prevent duplicate task execution
|
||||
# but this is much simpler than the full fencing mechanism
|
||||
r = get_redis_client()
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Docfetching task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise SimpleJobException(
|
||||
f"Docfetching task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise SimpleJobException(
|
||||
f"cc_pair not found: cc_pair={cc_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
app,
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
except ConnectorValidationError:
|
||||
raise SimpleJobException(
|
||||
f"Indexing task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# special bulletproofing ... truncate long exception messages
|
||||
# for exception types that require more args, this will fail
|
||||
# thus the try/except
|
||||
try:
|
||||
sanitized_e = type(e)(str(e)[:1024])
|
||||
sanitized_e.__traceback__ = e.__traceback__
|
||||
raise sanitized_e
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
os._exit(0) # ensure process exits cleanly
|
||||
|
||||
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
result.connector_source = connector_source
|
||||
|
||||
if job.process:
|
||||
result.exit_code = job.process.exitcode
|
||||
|
||||
if job.status != "error":
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
return result
|
||||
|
||||
ignore_exitcode = False
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...",
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
else:
|
||||
if result.exit_code is not None:
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
result.exception_str = job.exception()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
)
|
||||
def docfetching_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
|
||||
docfetching and docprocessing.
|
||||
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
|
||||
|
||||
This task spawns a new process for a new scheduled index attempt. That
|
||||
new process (which runs the docfetching_task function) does the following:
|
||||
|
||||
1) determines parameters of the indexing attempt (which connector indexing function to run,
|
||||
start and end time, from prev checkpoint or not), then run that connector. Specifically,
|
||||
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
|
||||
At the moment these two steps (reading external data and converting to an Onyx document)
|
||||
are not parallelized in most connectors; that's a subject for future work.
|
||||
|
||||
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned
|
||||
to process it. docprocessing involves the steps listed below.
|
||||
|
||||
2) upserts documents to postgres (index_doc_batch_prepare)
|
||||
3) chunks each document (optionally adds context for contextual rag)
|
||||
4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
|
||||
5) write chunks to vespa (write_chunks_to_vector_db_with_backoff)
|
||||
6) update document and indexing metadata in postgres
|
||||
7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes
|
||||
all locally stored IDs missing from the most recently pulled document ID list
|
||||
|
||||
Some important notes:
|
||||
Invariants:
|
||||
- docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog.
|
||||
The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed
|
||||
if it is not making progress.
|
||||
- All docprocessing tasks are spawned by a docfetching task.
|
||||
- all docfetching tasks, docprocessing tasks, and document batches in the file store are
|
||||
associated with a specific index attempt.
|
||||
- the index attempt status is the source of truth for what is currently happening with the index attempt.
|
||||
It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible.
|
||||
|
||||
How we deal with failures/ partial indexing:
|
||||
- non-checkpointed connectors/ new runs in general => delete the old document batches from the file store and do the new run
|
||||
- checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run
|
||||
|
||||
Misc:
|
||||
- most inter-process communication is handled in postgres, some is still in redis and we're trying to remove it
|
||||
- Heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveness
- progress-based liveness check: if nothing is done in 3-6 hours, mark the attempt as failed
|
||||
- TODO: task level timeouts (i.e. a connector stuck in an infinite loop)
|
||||
|
||||
|
||||
Comments below are from the old version and some may no longer be valid.
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
|
||||
Some more Richard notes:
|
||||
celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
|
||||
To work around this, we use pool=threads and proxy our work to a spawned task.
|
||||
|
||||
acks_late must be set to False. Otherwise, celery's visibility timeout will
|
||||
cause any task that runs longer than the timeout to be redispatched by the broker.
|
||||
There appears to be no good workaround for this, so we need to handle redispatching
|
||||
manually.
|
||||
NOTE: we try/except all db access in this function because as a watchdog, this function
|
||||
needs to be extremely stable.
|
||||
"""
|
||||
# TODO: remove dependence on Redis
|
||||
start = time.monotonic()
|
||||
|
||||
result = SimpleJobResult()
|
||||
|
||||
ctx = DocProcessingContext(
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
log_builder = ConnectorIndexingLogBuilder(ctx)
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - starting",
|
||||
mp_start_method=str(multiprocessing.get_start_method()),
|
||||
)
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
|
||||
|
||||
job = client.submit(
|
||||
docfetching_task,
|
||||
self.app,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
global_version.is_ee_version(),
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
if not job or not job.process:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the process has moved out of the starting state
|
||||
num_waits = 0
|
||||
while True:
|
||||
if num_waits > 15:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
job.release()
|
||||
return
|
||||
|
||||
if job.process.is_alive() or job.process.exitcode is not None:
|
||||
break
|
||||
|
||||
sleep(1)
|
||||
num_waits += 1
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawn succeeded",
|
||||
pid=str(job.process.pid),
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
    last_memory_emit_time = 0.0

    try:
        with get_session_with_current_tenant() as db_session:
            index_attempt = get_index_attempt(
                db_session=db_session,
                index_attempt_id=index_attempt_id,
                eager_load_cc_pair=True,
            )
            if not index_attempt:
                raise RuntimeError("Index attempt not found")

            result.connector_source = (
                index_attempt.connector_credential_pair.connector.source.value
            )

        while True:
            sleep(5)

            time.monotonic()

            # if the job is done, clean up and break
            if job.done():
                try:
                    result = process_job_result(
                        job, result.connector_source, redis_connector_index, log_builder
                    )
                except Exception:
                    task_logger.exception(
                        log_builder.build(
                            "Indexing watchdog - spawned task exceptioned"
                        )
                    )
                finally:
                    job.release()
                    break

            # log the memory usage for tracking down memory leaks / connector-specific memory issues
            pid = job.process.pid
            if pid is not None:
                # Only emit memory info once per minute (60 seconds)
                current_time = time.monotonic()
                if current_time - last_memory_emit_time >= 60.0:
                    emit_process_memory(
                        pid,
                        "indexing_worker",
                        {
                            "cc_pair_id": cc_pair_id,
                            "search_settings_id": search_settings_id,
                            "index_attempt_id": index_attempt_id,
                        },
                    )
                    last_memory_emit_time = current_time

            # if the spawned task is still running, restart the check once again
            # if the index attempt is not in a finished status
            try:
                with get_session_with_current_tenant() as db_session:
                    index_attempt = get_index_attempt(
                        db_session=db_session, index_attempt_id=index_attempt_id
                    )

                    if not index_attempt:
                        continue

                    if not index_attempt.is_finished():
                        continue

            except Exception:
                task_logger.exception(
                    log_builder.build(
                        "Indexing watchdog - transient exception looking up index attempt"
                    )
                )
                continue

    except Exception as e:
        result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
        if isinstance(e, ConnectorValidationError):
            # No need to expose full stack trace for validation errors
            result.exception_str = str(e)
        else:
            result.exception_str = traceback.format_exc()

    # handle exit and reporting
    elapsed = time.monotonic() - start
    if result.exception_str is not None:
        # print with exception
        try:
            with get_session_with_current_tenant() as db_session:
                failure_reason = (
                    f"Spawned task exceptioned: exit_code={result.exit_code}"
                )
                mark_attempt_failed(
                    ctx.index_attempt_id,
                    db_session,
                    failure_reason=failure_reason,
                    full_exception_trace=result.exception_str,
                )
        except Exception:
            task_logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as failed"
                )
            )

        normalized_exception_str = "None"
        if result.exception_str:
            normalized_exception_str = result.exception_str.replace(
                "\n", "\\n"
            ).replace('"', '\\"')

        task_logger.warning(
            log_builder.build(
                "Indexing watchdog - finished",
                source=result.connector_source,
                status=result.status.value,
                exit_code=str(result.exit_code),
                exception=f'"{normalized_exception_str}"',
                elapsed=f"{elapsed:.2f}s",
            )
        )
        raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")

    # print without exception
    if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
        try:
            with get_session_with_current_tenant() as db_session:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to termination signal"
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session,
                    "Connector termination signal detected",
                )
        except Exception:
            task_logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as canceled"
                )
            )

        job.cancel()
    elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
        try:
            with get_session_with_current_tenant() as db_session:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session,
                    "Indexing watchdog - activity timeout exceeded: "
                    f"attempt={index_attempt_id} "
                    f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
                )
        except Exception:
            logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as failed"
                )
            )
        job.cancel()
    else:
        pass

    task_logger.info(
        log_builder.build(
            "Indexing watchdog - finished",
            source=result.connector_source,
            status=str(result.status.value),
            exit_code=str(result.exit_code),
            elapsed=f"{elapsed:.2f}s",
        )
    )
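A standalone sketch of the exception normalization used above: the traceback is flattened so the structured "Indexing watchdog - finished" log entry stays a single greppable line. The helper name here is illustrative only, not something that exists in the codebase.

import traceback


def normalize_exception_for_log(exception_str: str | None) -> str:
    # mirror the escaping above: newlines and double quotes are escaped so the
    # whole traceback fits into one quoted log field
    if not exception_str:
        return "None"
    return exception_str.replace("\n", "\\n").replace('"', '\\"')


try:
    raise ValueError("boom")
except ValueError:
    single_line = normalize_exception_for_log(traceback.format_exc())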
@@ -0,0 +1,36 @@
import threading

from sqlalchemy import update

from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt


def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
    """Start a heartbeat thread for the given index attempt"""
    stop_event = threading.Event()

    def heartbeat_loop() -> None:
        while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL):
            try:
                with get_session_with_current_tenant() as db_session:
                    db_session.execute(
                        update(IndexAttempt)
                        .where(IndexAttempt.id == index_attempt_id)
                        .values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
                    )
                    db_session.commit()
            except Exception:
                # Silently continue if heartbeat fails
                pass

    thread = threading.Thread(target=heartbeat_loop, daemon=True)
    thread.start()
    return thread, stop_event


def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None:
    """Stop the heartbeat thread"""
    stop_event.set()
    thread.join(timeout=5)  # Wait up to 5 seconds for clean shutdown
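A minimal usage sketch for the heartbeat helpers above. It assumes start_heartbeat and stop_heartbeat are imported from the new heartbeat module (its file path is not visible in this excerpt), and the wrapper function itself is illustrative, not part of the change.

def run_with_heartbeat(index_attempt_id: int) -> None:
    thread, stop_event = start_heartbeat(index_attempt_id)
    try:
        # placeholder for the real docfetching / docprocessing work
        pass
    finally:
        # always stop the heartbeat so the daemon thread exits promptly
        stop_heartbeat(thread, stop_event)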
backend/onyx/background/celery/tasks/docprocessing/tasks.py (new file, 1283 lines; diff suppressed because it is too large)
@@ -1,10 +1,8 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
@@ -12,8 +10,6 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
@@ -21,27 +17,19 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt
|
||||
from onyx.db.index_attempt import delete_index_attempt
|
||||
from onyx.db.index_attempt import get_all_index_attempts_by_status
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -50,54 +38,6 @@ logger = setup_logger()
|
||||
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
|
||||
|
||||
|
||||
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
|
||||
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
|
||||
want to clean them up.
|
||||
|
||||
Unfenced = attempt not in terminal state and fence does not exist.
|
||||
"""
|
||||
unfenced_attempts: list[int] = []
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
# inner = index_attempt in non terminal state
|
||||
# outer = r.fence_key down
|
||||
|
||||
# check the db for index attempts in a non terminal state
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
)
|
||||
|
||||
# if the fence is down / doesn't exist, possible error but not confirmed
|
||||
if r.exists(fence_key):
|
||||
continue
|
||||
|
||||
# Between the time the attempts are first looked up and the time we see the fence down,
|
||||
# the attempt may have completed and taken down the fence normally.
|
||||
|
||||
# We need to double check that the index attempt is still in a non terminal state
|
||||
# and matches the original state, which confirms we are really in a bad state.
|
||||
attempt_2 = get_index_attempt(db_session, attempt.id)
|
||||
if not attempt_2:
|
||||
continue
|
||||
|
||||
if attempt.status != attempt_2.status:
|
||||
continue
|
||||
|
||||
unfenced_attempts.append(attempt.id)
|
||||
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
@@ -123,10 +63,9 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
return False
|
||||
# Check if the associated indexing attempt has been cancelled
|
||||
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
|
||||
return bool(self.redis_connector.stop.fenced)
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
"""Amount isn't used yet."""
|
||||
@@ -178,179 +117,16 @@ class IndexingCallback(IndexingCallbackBase):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
):
|
||||
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
|
||||
|
||||
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
self.redis_connector_index.set_active()
|
||||
self.redis_connector_index.set_connector_active()
|
||||
super().progress(tag, amount)
|
||||
self.redis_client.incrby(
|
||||
self.redis_connector_index.generator_progress_key, amount
|
||||
)
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - "
|
||||
f"Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
|
||||
f"index_attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
mark_attempt_failed(
|
||||
payload.index_attempt_id,
|
||||
db_session,
|
||||
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
|
||||
f"index_attempt={payload.index_attempt_id}",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"validate_indexing_fence - Exception while marking index attempt as failed: "
|
||||
f"index_attempt={payload.index_attempt_id}",
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
"""Validates all indexing fences for this tenant ... aka makes sure
|
||||
indexing tasks sent to celery are still in flight.
|
||||
"""
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# Use replica for this because the worst thing that happens
|
||||
# is that we don't run the validation on this pass
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
# NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed
|
||||
# as they are no longer needed with database-based coordination. The new validation is
|
||||
# handled by validate_active_indexing_attempts in the main indexing tasks module.
|
||||
|
||||
|
||||
def is_in_repeated_error_state(
|
||||
@@ -414,10 +190,12 @@ def should_index(
|
||||
)
|
||||
|
||||
# uncomment for debugging
|
||||
# task_logger.info(f"_should_index: "
|
||||
# f"cc_pair={cc_pair.id} "
|
||||
# f"connector={cc_pair.connector_id} "
|
||||
# f"refresh_freq={connector.refresh_freq}")
|
||||
task_logger.info(
|
||||
f"_should_index: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"refresh_freq={connector.refresh_freq}"
|
||||
)
|
||||
|
||||
# don't kick off indexing for `NOT_APPLICABLE` sources
|
||||
if connector.source == DocumentSource.NOT_APPLICABLE:
|
||||
@@ -517,7 +295,7 @@ def should_index(
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_indexing_task(
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
@@ -531,10 +309,11 @@ def try_creating_indexing_task(
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
index_attempt_id: int | None = None
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
@@ -547,61 +326,42 @@ def try_creating_indexing_task(
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
redis_connector_index: RedisConnectorIndex
|
||||
index_attempt_id = None
|
||||
try:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
# skip if already indexing
|
||||
if redis_connector_index.fenced:
|
||||
return None
|
||||
|
||||
# skip indexing if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# add a long running generator task to the queue
|
||||
redis_connector_index.generator_clear()
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexPayload(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_active()
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
# code elsewhere checks for index attempts without an associated redis key
|
||||
# and cleans them up
|
||||
# therefore we must create the attempt and the task after the fence goes up
|
||||
index_attempt_id = create_index_attempt(
|
||||
cc_pair.id,
|
||||
search_settings.id,
|
||||
from_beginning=reindex,
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
custom_task_id = redis_connector_index.generate_generator_task_id()
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_INDEXING
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -613,14 +373,18 @@ def try_creating_indexing_task(
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
redis_connector_index.set_active()
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
@@ -628,9 +392,10 @@ def try_creating_indexing_task(
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
delete_index_attempt(db_session, index_attempt_id)
|
||||
redis_connector_index.set_fence(None)
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
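A self-contained sketch of the task-id scheme and queue selection used by try_creating_docfetching_task above. The queue names are inlined as plain strings for illustration; the real code uses the OnyxCeleryQueues constants, and these helper functions do not exist in the codebase.

from uuid import uuid4


def build_docfetching_task_id(cc_pair_id: int, search_settings_id: int) -> str:
    # matches the custom task id format generated above for tracking
    return f"docfetching_{cc_pair_id}_{search_settings_id}_{uuid4()}"


def pick_docfetching_queue(is_user_file: bool) -> str:
    # user file uploads share the docfetching pipeline but run on their own queue
    return "user_files_indexing" if is_user_file else "connector_doc_fetching"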
(file diff suppressed because it is too large)
backend/onyx/background/celery/tasks/models.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from enum import Enum

from pydantic import BaseModel


class DocProcessingContext(BaseModel):
    tenant_id: str
    cc_pair_id: int
    search_settings_id: int
    index_attempt_id: int


class IndexingWatchdogTerminalStatus(str, Enum):
    """The different statuses the watchdog can finish with.

    TODO: create broader success/failure/abort categories
    """

    UNDEFINED = "undefined"

    SUCCEEDED = "succeeded"

    SPAWN_FAILED = "spawn_failed"  # connector spawn failed
    SPAWN_NOT_ALIVE = (
        "spawn_not_alive"  # spawn succeeded but process did not come alive
    )

    BLOCKED_BY_DELETION = "blocked_by_deletion"
    BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
    FENCE_NOT_FOUND = "fence_not_found"  # fence does not exist
    FENCE_READINESS_TIMEOUT = (
        "fence_readiness_timeout"  # fence exists but wasn't ready within the timeout
    )
    FENCE_MISMATCH = "fence_mismatch"  # task and fence metadata mismatch
    TASK_ALREADY_RUNNING = "task_already_running"  # task appears to be running already
    INDEX_ATTEMPT_MISMATCH = (
        "index_attempt_mismatch"  # expected index attempt metadata not found in db
    )

    CONNECTOR_VALIDATION_ERROR = (
        "connector_validation_error"  # the connector validation failed
    )
    CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

    # the watchdog received a termination signal
    TERMINATED_BY_SIGNAL = "terminated_by_signal"

    # the watchdog terminated the task due to no activity
    TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"

    # NOTE: this may actually be the same as SIGKILL, but parsed differently by python
    # consolidate once we know more
    OUT_OF_MEMORY = "out_of_memory"

    PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"

    @property
    def code(self) -> int:
        _ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
            IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
            IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
            IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
            IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
            IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
            IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
            IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
            IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
            IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
        }

        return _ENUM_TO_CODE[self]

    @classmethod
    def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
        _CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
            -9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
            137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
            247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
            248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
            249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
            250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
            251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
            252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
            253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
            254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
            255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
        }

        if code in _CODE_TO_ENUM:
            return _CODE_TO_ENUM[code]

        return IndexingWatchdogTerminalStatus.UNDEFINED


class SimpleJobResult:
    """The data we want to have when the watchdog finishes"""

    status: IndexingWatchdogTerminalStatus
    connector_source: str | None
    exit_code: int | None
    exception_str: str | None

    def __init__(self) -> None:
        self.status = IndexingWatchdogTerminalStatus.UNDEFINED
        self.connector_source = None
        self.exit_code = None
        self.exception_str = None
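A small usage sketch of the models above: mapping a spawned process exit code back to a terminal status and recording it on a SimpleJobResult. The import path follows the file header shown above (backend/onyx/background/celery/tasks/models.py).

from onyx.background.celery.tasks.models import (
    IndexingWatchdogTerminalStatus,
    SimpleJobResult,
)

result = SimpleJobResult()
result.exit_code = 137  # e.g. the spawned connector process was OOM-killed
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)

assert result.status is IndexingWatchdogTerminalStatus.OUT_OF_MEMORY
assert result.status.code == 137  # round-trips for every mapped status
# unmapped exit codes fall back to UNDEFINED rather than raising
assert IndexingWatchdogTerminalStatus.from_code(1) is IndexingWatchdogTerminalStatus.UNDEFINED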
@@ -147,7 +147,7 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
    metrics = []
    queue_mappings = {
        "celery_queue_length": "celery",
        "indexing_queue_length": "indexing",
        "docprocessing_queue_length": "docprocessing",
        "sync_queue_length": "sync",
        "deletion_queue_length": "deletion",
        "pruning_queue_length": "pruning",
@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(

    r_celery = task.app.broker_connection().channel().client  # type: ignore
    n_celery = celery_get_queue_length("celery", r_celery)
    n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
    n_docfetching = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
    n_user_files_indexing = celery_get_queue_length(
        OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
    )
    n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
    n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
    n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )

    n_indexing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    n_docfetching_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.DOCPROCESSING, r_celery
    )

    task_logger.info(
        f"Queue lengths: celery={n_celery} "
        f"indexing={n_indexing} "
        f"indexing_prefetched={len(n_indexing_prefetched)} "
        f"docfetching={n_docfetching} "
        f"docfetching_prefetched={len(n_docfetching_prefetched)} "
        f"docprocessing={n_docprocessing} "
        f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
        f"user_files_indexing={n_user_files_indexing} "
        f"sync={n_sync} "
        f"deletion={n_deletion} "
        f"pruning={n_pruning} "

@@ -22,7 +22,7 @@ from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -464,7 +464,7 @@ def connector_pruning_generator_task(
    # set thread_local=False since we don't control what thread the indexing/pruning
    # might run our callback with
    lock: RedisLock = r.lock(
        OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
        OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
        timeout=CELERY_PRUNING_LOCK_TIMEOUT,
        thread_local=False,
    )

@@ -10,7 +10,7 @@ set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.indexing import celery_app
    from onyx.background.celery.apps.docfetching import celery_app

    return celery_app

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""

from celery import Celery

from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.docprocessing import celery_app

    return celery_app


app = get_app()
@@ -33,7 +33,7 @@ def save_checkpoint(
    """Save a checkpoint for a given index attempt to the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.save_file(
        content=BytesIO(checkpoint.model_dump_json().encode()),
        display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(


def load_checkpoint(
    db_session: Session, index_attempt_id: int, connector: BaseConnector
    index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
    """Load a checkpoint for a given index attempt from the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
    checkpoint_data = checkpoint_io.read().decode("utf-8")
    if isinstance(connector, CheckpointedConnector):
@@ -71,7 +71,7 @@ def get_latest_valid_checkpoint(
    window_start: datetime,
    window_end: datetime,
    connector: BaseConnector,
) -> ConnectorCheckpoint:
) -> tuple[ConnectorCheckpoint, bool]:
    """Get the latest valid checkpoint for a given connector credential pair"""
    checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
        cc_pair_id=cc_pair_id,
@@ -83,7 +83,7 @@ def get_latest_valid_checkpoint(
    # don't keep using checkpoints if we've had a bunch of failed attempts in a row
    # where we make no progress. Only do this if we have had at least
    # _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts.
    if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
    if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
        had_any_progress = False
        for candidate in checkpoint_candidates:
            if (
@@ -99,7 +99,7 @@ def get_latest_valid_checkpoint(
            f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
            "from scratch."
        )
        return connector.build_dummy_checkpoint()
        return connector.build_dummy_checkpoint(), False

    # filter out any candidates that don't meet the criteria
    checkpoint_candidates = [
@@ -140,11 +140,10 @@ def get_latest_valid_checkpoint(
        logger.info(
            f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch."
        )
        return checkpoint
        return checkpoint, False

    try:
        previous_checkpoint = load_checkpoint(
            db_session=db_session,
            index_attempt_id=latest_valid_checkpoint_candidate.id,
            connector=connector,
        )
@@ -153,14 +152,14 @@ def get_latest_valid_checkpoint(
            f"Failed to load checkpoint from previous failed attempt with ID "
            f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint."
        )
        return checkpoint
        return checkpoint, False

    logger.info(
        f"Using checkpoint from previous failed attempt with ID "
        f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
        f"{previous_checkpoint}"
    )
    return previous_checkpoint
    return previous_checkpoint, True


def get_index_attempts_with_old_checkpoints(
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
    if not index_attempt.checkpoint_pointer:
        return None

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.delete_file(index_attempt.checkpoint_pointer)

    index_attempt.checkpoint_pointer = None
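Since get_latest_valid_checkpoint now returns a (checkpoint, resumed) pair, callers unpack it and branch on the flag. A condensed, hedged sketch of that calling pattern; the keyword arguments match the call sites shown in this diff, and the surrounding variables are assumed to be in scope as in the docfetching caller later on.

checkpoint, resumed = get_latest_valid_checkpoint(
    db_session=db_session,
    cc_pair_id=cc_pair_id,
    search_settings_id=search_settings_id,
    window_start=window_start,
    window_end=window_end,
    connector=connector,
)
if resumed:
    # picked up the checkpoint of a recent failed/canceled attempt
    ...
else:
    # dummy checkpoint: the run starts from scratch for this window
    ...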
@@ -1,3 +1,4 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
@@ -5,7 +6,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import source_should_fetch_permissions_during_indexing
|
||||
@@ -18,18 +19,25 @@ from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
@@ -49,13 +57,16 @@ from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
@@ -68,7 +79,7 @@ from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
logger = setup_logger(propagate=False)
|
||||
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
|
||||
@@ -146,6 +157,10 @@ def _get_connector_runner(
|
||||
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
cleaned_batch = []
|
||||
for doc in doc_batch:
|
||||
if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES:
|
||||
logger.warning(
|
||||
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
|
||||
)
|
||||
cleaned_doc = doc.model_copy()
|
||||
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
@@ -180,25 +195,11 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
return cleaned_batch
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
class RunIndexingContext(BaseModel):
|
||||
index_name: str
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
should_fetch_permissions_during_indexing: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _check_connector_and_attempt_status(
|
||||
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
|
||||
db_session_temp: Session,
|
||||
cc_pair_id: int,
|
||||
search_settings_status: IndexModelStatus,
|
||||
index_attempt_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Checks the status of the connector credential pair and index attempt.
|
||||
@@ -206,27 +207,34 @@ def _check_connector_and_attempt_status(
|
||||
"""
|
||||
cc_pair_loop = get_connector_credential_pair_from_id(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
cc_pair_id,
|
||||
)
|
||||
if not cc_pair_loop:
|
||||
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
|
||||
raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.")
|
||||
|
||||
if (
|
||||
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and ctx.search_settings_status != IndexModelStatus.FUTURE
|
||||
and search_settings_status != IndexModelStatus.FUTURE
|
||||
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
|
||||
raise RuntimeError("Connector was disabled mid run")
|
||||
raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}")
|
||||
|
||||
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_loop:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
if index_attempt_loop.status == IndexingStatus.CANCELED:
|
||||
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
raise RuntimeError(
|
||||
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
|
||||
f"Index Attempt is not running, status is {index_attempt_loop.status}"
|
||||
)
|
||||
|
||||
if index_attempt_loop.celery_task_id is None:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id")
|
||||
|
||||
|
||||
# TODO: delete from here if ends up unused
|
||||
def _check_failure_threshold(
|
||||
total_failures: int,
|
||||
document_count: int,
|
||||
@@ -257,6 +265,9 @@ def _check_failure_threshold(
|
||||
)
|
||||
|
||||
|
||||
# NOTE: this is the old run_indexing function that the new decoupled approach
|
||||
# is based on. Leaving this for comparison purposes, but if you see this comment
|
||||
# has been here for >1 month, please delete this function.
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
@@ -271,7 +282,12 @@ def _run_indexing(
|
||||
start_time = time.monotonic()  # just used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
index_attempt_start = get_index_attempt(
|
||||
db_session_temp,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
@@ -292,7 +308,7 @@ def _run_indexing(
|
||||
index_attempt_start.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = RunIndexingContext(
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
@@ -317,6 +333,7 @@ def _run_indexing(
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None,
|
||||
)
|
||||
|
||||
last_successful_index_poll_range_end = (
|
||||
@@ -384,19 +401,6 @@ def _run_indexing(
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
@@ -416,7 +420,9 @@ def _run_indexing(
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
index_attempt = get_index_attempt(
|
||||
db_session_temp, index_attempt_id, eager_load_cc_pair=True
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
@@ -439,7 +445,7 @@ def _run_indexing(
|
||||
):
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
checkpoint, _ = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
@@ -496,7 +502,10 @@ def _run_indexing(
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp, ctx, index_attempt_id
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
ctx.search_settings_status,
|
||||
index_attempt_id,
|
||||
)
|
||||
|
||||
# save record of any failures at the connector level
|
||||
@@ -554,7 +563,16 @@ def _run_indexing(
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
index_pipeline_result = indexing_pipeline(
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
@@ -815,6 +833,7 @@ def _run_indexing(
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
connector_credential_pair_id: int,
|
||||
@@ -832,7 +851,6 @@ def run_indexing_entrypoint(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
@@ -846,18 +864,516 @@ def run_indexing_entrypoint(
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"Docfetching starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
connector_document_extraction(
|
||||
app,
|
||||
index_attempt_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
tenant_id,
|
||||
callback,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"Docfetching finished{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
|
||||
def connector_document_extraction(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""Extract documents from connector and queue them for indexing pipeline processing.
|
||||
|
||||
This is the first part of the split indexing process that runs the connector
|
||||
and extracts documents, storing them in the filestore for later processing.
|
||||
"""
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
f"Document extraction starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
|
||||
batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
|
||||
index_attempt = None
|
||||
last_batch_num = 0 # used to continue from checkpointing
|
||||
# the setup below is largely carried over from the old _run_indexing implementation
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError("Search settings must be set for indexing")
|
||||
|
||||
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
|
||||
if index_attempt.connector_credential_pair.indexing_trigger is not None:
|
||||
logger.info(
|
||||
"Clearing indexing trigger: "
|
||||
f"cc_pair={index_attempt.connector_credential_pair.id} "
|
||||
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
|
||||
)
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
index_attempt.connector_credential_pair.id, None, db_session
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
from_beginning = index_attempt.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
)
|
||||
should_fetch_permissions_during_indexing = (
|
||||
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
|
||||
and source_should_fetch_permissions_during_indexing(db_connector.source)
|
||||
and is_primary
|
||||
# if we've already successfully indexed, let the doc_sync job
|
||||
# take care of doc-level permissions
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
)
|
||||
|
||||
# Set up time windows for polling
|
||||
last_successful_index_poll_range_end = (
|
||||
earliest_index_time
|
||||
if from_beginning
|
||||
else get_last_successful_attempt_poll_range_end(
|
||||
cc_pair_id=cc_pair_id,
|
||||
earliest_index=earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_poll_range_end, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
db_session=db_session,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# set time range in db
|
||||
index_attempt.poll_range_start = window_start
|
||||
index_attempt.poll_range_end = window_end
|
||||
db_session.commit()
|
||||
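The poll-window computation above can be restated as a small pure function for clarity. This helper does not exist in the codebase; it is a hedged restatement of the logic shown, with POLL_CONNECTOR_OFFSET expressed in minutes as in the real config.

from datetime import datetime, timedelta, timezone


def compute_poll_window(
    last_poll_range_end: float,  # epoch seconds of the last successful poll end
    poll_offset_minutes: int,
    previous_failed_poll_range_end: datetime | None,
) -> tuple[datetime, datetime]:
    # subtract the poll offset from the last successful end; the raw comparison
    # against the offset mirrors the "have we ever indexed before" check above
    # and avoids going into "negative" time
    if last_poll_range_end > poll_offset_minutes:
        start = datetime.fromtimestamp(last_poll_range_end, tz=timezone.utc) - timedelta(
            minutes=poll_offset_minutes
        )
    else:
        start = datetime.fromtimestamp(0, tz=timezone.utc)

    # if the most recent attempt failed or was canceled, reuse its window end so
    # checkpoint resumption stays consistent; otherwise poll up to "now"
    end = previous_failed_poll_range_end or datetime.now(tz=timezone.utc)
    return start, end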
|
||||
# TODO: maybe memory tracer here
|
||||
|
||||
# Set up connector runner
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=should_fetch_permissions_during_indexing,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling
|
||||
# OR
|
||||
# if the last attempt was successful
|
||||
if index_attempt.from_beginning or (
|
||||
most_recent_attempt and most_recent_attempt.status.is_successful()
|
||||
):
|
||||
logger.info(
|
||||
f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run"
|
||||
)
|
||||
batch_storage.cleanup_all_batches()
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
logger.info(
|
||||
f"Getting latest valid checkpoint for index attempt {index_attempt_id}"
|
||||
)
|
||||
checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(connector_runner.connector, CheckpointedConnector)
|
||||
and resuming_from_checkpoint
|
||||
):
|
||||
reissued_batch_count, completed_batches = reissue_old_batches(
|
||||
batch_storage,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
tenant_id,
|
||||
app,
|
||||
most_recent_attempt,
|
||||
)
|
||||
last_batch_num = reissued_batch_count + completed_batches
|
||||
index_attempt.completed_batches = completed_batches
|
||||
db_session.commit()
|
||||
else:
|
||||
logger.info(
|
||||
f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run"
|
||||
)
|
||||
# for non-checkpointed connectors, throw out batches from previous unsuccessful attempts
|
||||
# because we'll be getting those documents again anyways.
|
||||
batch_storage.cleanup_all_batches()
|
||||
|
||||
# Save initial checkpoint
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
try:
|
||||
batch_num = last_batch_num # starts at 0 if no last batch
|
||||
total_doc_batches_queued = 0
|
||||
total_failures = 0
|
||||
document_count = 0
|
||||
|
||||
# Main extraction loop
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# NOTE: this progress callback runs on every loop. We've seen cases
|
||||
# where we loop many times with no new documents and eventually time
|
||||
# out, so only doing the callback after indexing isn't sufficient.
|
||||
# TODO: change to doc extraction if it doesn't break things
|
||||
callback.progress("_run_indexing", 0)
|
||||
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
with get_session_with_current_tenant() as db_session_tmp:
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_tmp,
|
||||
cc_pair_id,
|
||||
index_attempt.search_settings.status,
|
||||
index_attempt_id,
|
||||
)
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
failure,
|
||||
db_session,
|
||||
)
|
||||
_check_failure_threshold(
|
||||
total_failures, document_count, batch_num, failure
|
||||
)
|
||||
|
||||
# Save checkpoint if provided
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# everything below is document processing work, so if there's no batch we can just continue
|
||||
if document_batch is None:
|
||||
continue
|
||||
|
||||
# Clean documents and create batch
|
||||
doc_batch_cleaned = strip_null_characters(document_batch)
|
||||
batch_description = []
|
||||
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
if (
|
||||
isinstance(section, TextSection)
|
||||
and section.text is not None
|
||||
):
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# Store documents in storage
|
||||
batch_storage.store_batch(batch_num, doc_batch_cleaned)
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# Check checkpoint size periodically
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
|
||||
check_checkpoint_size(checkpoint)
|
||||
|
||||
# Save latest checkpoint
|
||||
# NOTE: checkpointing is now used to track which batches have been sent to the
# filestore, NOT which batches have been fully indexed (as it used to be).
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Document extraction completed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"batches_queued={total_doc_batches_queued} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
# Set total batches in database to signal extraction completion.
|
||||
# Used by check_for_indexing to determine if the index attempt is complete.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
IndexingCoordination.set_total_batches(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_batches=batch_num,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Document extraction failed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
|
||||
# Do NOT clean up batches on failure; future runs will use those batches
|
||||
# while docfetching will continue from the saved checkpoint if one exists
|
||||
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# On validation errors during indexing, we want to cancel the indexing attempt
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to validation error."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
if not index_attempt:
|
||||
# should always be set by now
|
||||
raise RuntimeError("Should never happen.")
|
||||
|
||||
VALIDATION_ERROR_THRESHOLD = 5
|
||||
|
||||
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
limit=VALIDATION_ERROR_THRESHOLD,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
num_validation_errors = len(
|
||||
[
|
||||
index_attempt
|
||||
for index_attempt in recent_index_attempts
|
||||
if index_attempt.error_msg
|
||||
and index_attempt.error_msg.startswith(
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Connector {db_connector.id} has {num_validation_errors} consecutive validation"
|
||||
f" errors. Marking the CC Pair as invalid."
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise e
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
else:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# don't overwrite attempts that are already failed/canceled for another reason
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if index_attempt and index_attempt.status in [
|
||||
IndexingStatus.CANCELED,
|
||||
IndexingStatus.FAILED,
|
||||
]:
|
||||
logger.info(
|
||||
f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed."
|
||||
)
|
||||
raise e
|
||||
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
finally:
|
||||
memory_tracer.stop()
|
||||
|
||||
|
||||
def reissue_old_batches(
|
||||
batch_storage: DocumentBatchStorage,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
app: Celery,
|
||||
most_recent_attempt: IndexAttempt | None,
|
||||
) -> tuple[int, int]:
|
||||
# When loading from a checkpoint, we need to start new docprocessing tasks
|
||||
# tied to the new index attempt for any batches left over in the file store
|
||||
old_batches = batch_storage.get_all_batches_for_cc_pair()
|
||||
batch_storage.update_old_batches_to_new_index_attempt(old_batches)
|
||||
for batch_id in old_batches:
|
||||
logger.info(
|
||||
f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}"
|
||||
)
|
||||
path_info = batch_storage.extract_path_info(batch_id)
|
||||
if path_info is None:
|
||||
continue
|
||||
if path_info.cc_pair_id != cc_pair_id:
|
||||
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
|
||||
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": path_info.batch_num, # use same batch num as previously
|
||||
},
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
|
||||
# resume from the batch num of the last attempt. This should be one more
|
||||
# than the last batch created by docfetching regardless of whether the batch
|
||||
# is still in the filestore waiting for processing or not.
|
||||
last_batch_num = len(old_batches) + recent_batches
|
||||
logger.info(
|
||||
f"Starting from batch {last_batch_num} due to "
|
||||
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
|
||||
)
|
||||
return len(old_batches), recent_batches
|
||||
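As a rough illustration of the resumption math above (a sketch with made-up numbers, not taken from this diff): if the previous attempt had completed 7 batches and 3 unprocessed batches are still sitting in the filestore, the 3 leftovers are re-queued under their original batch numbers and docfetching resumes numbering at 10.

# hypothetical values for illustration only
reissued_batch_count, completed_batches = 3, 7
last_batch_num = reissued_batch_count + completed_batches  # == 10; next new batch number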
@@ -725,9 +725,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
)
|
||||
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
user_file_ids = new_msg_req.user_file_ids or []
|
||||
|
||||
@@ -311,18 +311,33 @@ except ValueError:
|
||||
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||
env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_INDEXING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
|
||||
env_value = str(CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
|
||||
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
|
||||
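The configuration blocks above repeat the same read-an-env-var-with-fallback pattern. A hypothetical helper (not part of this change; `_read_int_env` is an invented name used purely for illustration) could express it once:

import os

def _read_int_env(primary: str, fallback: str | None, default: int) -> int:
    # Prefer the new variable name, fall back to the legacy one, then the default.
    value = os.environ.get(primary) or (os.environ.get(fallback) if fallback else None)
    try:
        return int(value) if value else default
    except ValueError:
        return default

# e.g. CELERY_WORKER_DOCPROCESSING_CONCURRENCY = _read_int_env(
#     "CELERY_WORKER_DOCPROCESSING_CONCURRENCY", "NUM_INDEXING_WORKERS", 6)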
@@ -65,7 +65,8 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
|
||||
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
@@ -121,6 +122,8 @@ CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
|
||||
# hard termination should always fire first if the connector is hung
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
|
||||
|
||||
# Heartbeat interval for indexing worker liveness detection
|
||||
INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
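The heartbeat interval above feeds worker liveness detection; the actual mechanism ("add heartbeat file" in the commit log) lives elsewhere in the PR. A minimal sketch of what such a heartbeat might look like, assuming a simple touch-a-file approach (the file path and thread wiring are assumptions, not the real implementation):

import threading
from pathlib import Path

def start_heartbeat(path: Path, interval: int = 30) -> threading.Event:
    # Touch `path` every `interval` seconds until the returned event is set.
    stop = threading.Event()

    def _beat() -> None:
        while not stop.wait(interval):
            path.touch()

    threading.Thread(target=_beat, daemon=True).start()
    return stop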
@@ -331,9 +334,12 @@ class OnyxCeleryQueues:
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# Document processing pipeline queue
|
||||
DOCPROCESSING = "docprocessing"
|
||||
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
@@ -464,7 +470,11 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
|
||||
# New split indexing tasks
|
||||
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
|
||||
DOCPROCESSING_TASK = "docprocessing_task"
|
||||
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
|
||||
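With the split constants above, kicking off an indexing run means enqueueing a docfetching task, which fans out docprocessing tasks per batch (as shown earlier in this diff). A minimal dispatch sketch, assuming the kwargs mirror those of the docprocessing send_task call (the exact signature of the docfetching task is not shown here):

app.send_task(
    OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
    kwargs={
        "index_attempt_id": index_attempt_id,
        "cc_pair_id": cc_pair_id,
        "tenant_id": tenant_id,
    },
    queue=OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
    priority=OnyxCeleryPriority.MEDIUM,
)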
@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
# TODO: Refactor to avoid direct DB access in connector
|
||||
# This will require broader refactoring across the codebase
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
image_section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=downloaded_file,
|
||||
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
|
||||
display_name=file_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=downloaded_file,
|
||||
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
|
||||
display_name=file_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=[image_section],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=[image_section],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
except Exception:
|
||||
logger.exception(f"Error processing image {key}")
|
||||
continue
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
@@ -224,19 +223,17 @@ def _process_image_attachment(
|
||||
"""Process an image attachment by saving it without generating a summary."""
|
||||
try:
|
||||
# Use the standardized image storage and section creation
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, file_name = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=raw_bytes,
|
||||
file_id=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
section, file_name = store_image_and_create_section(
|
||||
image_data=raw_bytes,
|
||||
file_id=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
except Exception as e:
|
||||
msg = f"Image storage failed for {attachment['title']}: {e}"
|
||||
logger.error(msg, exc_info=e)
|
||||
|
||||
@@ -5,8 +5,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
@@ -32,7 +29,6 @@ logger = setup_logger()
|
||||
|
||||
def _create_image_section(
|
||||
image_data: bytes,
|
||||
db_session: Session,
|
||||
parent_file_name: str,
|
||||
display_name: str,
|
||||
link: str | None = None,
|
||||
@@ -58,7 +54,6 @@ def _create_image_section(
|
||||
# Store the image and create a section
|
||||
try:
|
||||
section, stored_file_name = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=image_data,
|
||||
file_id=file_id,
|
||||
display_name=display_name,
|
||||
@@ -77,7 +72,6 @@ def _process_file(
|
||||
file: IO[Any],
|
||||
metadata: dict[str, Any] | None,
|
||||
pdf_pass: str | None,
|
||||
db_session: Session,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Process a file and return a list of Documents.
|
||||
@@ -125,7 +119,6 @@ def _process_file(
|
||||
try:
|
||||
section, _ = _create_image_section(
|
||||
image_data=image_data,
|
||||
db_session=db_session,
|
||||
parent_file_name=file_id,
|
||||
display_name=title,
|
||||
)
|
||||
@@ -196,7 +189,6 @@ def _process_file(
|
||||
try:
|
||||
image_section, stored_file_name = _create_image_section(
|
||||
image_data=img_data,
|
||||
db_session=db_session,
|
||||
parent_file_name=file_id,
|
||||
display_name=f"{title} - image {idx}",
|
||||
idx=idx,
|
||||
@@ -260,37 +252,33 @@ class LocalFileConnector(LoadConnector):
|
||||
"""
|
||||
documents: list[Document] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for file_id in self.file_locations:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(file_id=file_id)
|
||||
if not file_record:
|
||||
# typically an unsupported extension
|
||||
logger.warning(
|
||||
f"No file record found for '{file_id}' in PG; skipping."
|
||||
)
|
||||
continue
|
||||
for file_id in self.file_locations:
|
||||
file_store = get_default_file_store()
|
||||
file_record = file_store.read_file_record(file_id=file_id)
|
||||
if not file_record:
|
||||
# typically an unsupported extension
|
||||
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
|
||||
continue
|
||||
|
||||
metadata = self._get_file_metadata(file_id)
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
file_name=file_record.display_name,
|
||||
file=file_io,
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
db_session=db_session,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
metadata = self._get_file_metadata(file_id)
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
file_name=file_record.display_name,
|
||||
file=file_io,
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
|
||||
documents = []
|
||||
|
||||
if documents:
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
|
||||
documents = []
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -1374,3 +1378,139 @@ class GoogleDriveConnector(
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
|
||||
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool) -> dict:
|
||||
if oauth:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
|
||||
else:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
|
||||
|
||||
refried_credential_string = json.dumps(json.loads(raw_credential_string))
|
||||
|
||||
# This is the OAuth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
cred_key = (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
if oauth
|
||||
else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
)
|
||||
return {
|
||||
cred_key: refried_credential_string,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
|
||||
}
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: GoogleDriveCheckpoint | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
|
||||
) -> CheckpointOutput[GoogleDriveCheckpoint]:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
def yield_all_docs_from_checkpoint_connector(
|
||||
connector: GoogleDriveConnector,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
num_iterations = 0
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
while checkpoint.has_more:
|
||||
doc_batch_generator = CheckpointOutputWrapper()(
|
||||
connector.load_from_checkpoint(start, end, checkpoint)
|
||||
)
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
yield failure
|
||||
if document is not None:
|
||||
yield document
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
num_iterations += 1
|
||||
if num_iterations > 100_000:
|
||||
raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
creds = get_credentials_from_env(
|
||||
os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
|
||||
)
|
||||
connector = GoogleDriveConnector(
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=True,
|
||||
specific_user_emails=None,
|
||||
)
|
||||
connector.load_credentials(creds)
|
||||
max_fsize = 0
|
||||
biggest_fsize = 0
|
||||
num_errors = 0
|
||||
start_time = time.time()
|
||||
with open("stats.txt", "w") as f:
|
||||
for num, doc_or_failure in enumerate(
|
||||
yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
|
||||
):
|
||||
if num % 200 == 0:
|
||||
f.write(f"Processed {num} files\n")
|
||||
f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
|
||||
f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
|
||||
f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
|
||||
biggest_fsize = max(biggest_fsize, max_fsize)
|
||||
max_fsize = 0
|
||||
if isinstance(doc_or_failure, Document):
|
||||
max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
|
||||
elif isinstance(doc_or_failure, ConnectorFailure):
|
||||
num_errors += 1
|
||||
print(f"Num errors: {num_errors}")
|
||||
print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
|
||||
print(f"Time taken: {time.time() - start_time:.2f} seconds")
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import docx_to_text_and_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -143,17 +142,15 @@ def _download_and_extract_sections_basic(
|
||||
# Store images for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response_call(),
|
||||
file_id=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=response_call(),
|
||||
file_id=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
@@ -216,16 +213,14 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=img_data,
|
||||
file_id=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=img_data,
|
||||
file_id=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -68,10 +67,7 @@ class GoogleSitesConnector(LoadConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_content_io = get_default_file_store(db_session).read_file(
|
||||
self.zip_path, mode="b"
|
||||
)
|
||||
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
|
||||
|
||||
# load the HTML files
|
||||
files = load_files_from_zip(file_content_io)
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.utils.text_processing import make_url_compatible
|
||||
|
||||
|
||||
@@ -363,6 +364,10 @@ class ConnectorFailure(BaseModel):
|
||||
return values
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
class OnyxMetadata(BaseModel):
|
||||
# Note that doc_id cannot be overridden here as it may cause issues
# with the display functionality in the UI. Ask @chris if clarification is needed.
|
||||
@@ -373,3 +378,24 @@ class OnyxMetadata(BaseModel):
|
||||
secondary_owners: list[BasicExpertInfo] | None = None
|
||||
doc_updated_at: datetime | None = None
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class DocExtractionContext(BaseModel):
|
||||
index_name: str
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
should_fetch_permissions_during_indexing: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
doc_extraction_complete_batch_num: int | None
|
||||
|
||||
|
||||
class DocIndexingContext(BaseModel):
|
||||
batches_done: int
|
||||
total_failures: int
|
||||
net_doc_change: int
|
||||
total_chunks: int
|
||||
|
||||
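A hedged sketch of how a DocExtractionContext might be assembled by the docfetching task (field values are illustrative assumptions, not taken from this diff):

context = DocExtractionContext(
    index_name="danswer_chunk",  # assumed index name, for illustration only
    cc_pair_id=cc_pair.id,
    connector_id=cc_pair.connector_id,
    credential_id=cc_pair.credential_id,
    source=cc_pair.connector.source,
    earliest_index_time=0.0,
    from_beginning=False,
    is_primary=True,
    should_fetch_permissions_during_indexing=False,
    search_settings_status=IndexModelStatus.PRESENT,
    doc_extraction_complete_batch_num=None,
)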
@@ -267,7 +267,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
result = ""
|
||||
for prop_name, prop in properties.items():
|
||||
if not prop:
|
||||
if not prop or not isinstance(prop, dict):
|
||||
continue
|
||||
|
||||
try:
|
||||
|
||||
@@ -992,7 +992,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.query_all(query)
|
||||
query_result = self.sf_client.safe_query_all(query)
|
||||
doc_metadata_list.extend(
|
||||
SlimDocument(
|
||||
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
from simple_salesforce.exceptions import SalesforceRefusedRequest
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS
|
||||
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES
|
||||
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
|
||||
"""Check if an exception is a Salesforce rate limit error."""
|
||||
return isinstance(
|
||||
exception, SalesforceRefusedRequest
|
||||
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
|
||||
|
||||
|
||||
class OnyxSalesforce(Salesforce):
|
||||
SOQL_MAX_SUBQUERIES = 20
|
||||
|
||||
@@ -52,6 +65,48 @@ class OnyxSalesforce(Salesforce):
|
||||
|
||||
return False
|
||||
|
||||
@retry_builder(
|
||||
tries=5,
|
||||
delay=20,
|
||||
backoff=1.5,
|
||||
max_delay=60,
|
||||
exceptions=(SalesforceRefusedRequest,),
|
||||
)
|
||||
@rate_limit_builder(max_calls=50, period=60)
|
||||
def safe_query(self, query: str, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Wrapper around the original query method with retry logic and rate limiting."""
|
||||
try:
|
||||
return super().query(query, **kwargs)
|
||||
except SalesforceRefusedRequest as e:
|
||||
if is_salesforce_rate_limit_error(e):
|
||||
logger.warning(
|
||||
f"Salesforce rate limit exceeded for query: {query[:100]}..."
|
||||
)
|
||||
# Add additional delay for rate limit errors
|
||||
time.sleep(5)
|
||||
raise
|
||||
|
||||
@retry_builder(
|
||||
tries=5,
|
||||
delay=20,
|
||||
backoff=1.5,
|
||||
max_delay=60,
|
||||
exceptions=(SalesforceRefusedRequest,),
|
||||
)
|
||||
@rate_limit_builder(max_calls=50, period=60)
|
||||
def safe_query_all(self, query: str, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Wrapper around the original query_all method with retry logic and rate limiting."""
|
||||
try:
|
||||
return super().query_all(query, **kwargs)
|
||||
except SalesforceRefusedRequest as e:
|
||||
if is_salesforce_rate_limit_error(e):
|
||||
logger.warning(
|
||||
f"Salesforce rate limit exceeded for query_all: {query[:100]}..."
|
||||
)
|
||||
# Add additional delay for rate limit errors
|
||||
time.sleep(5)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _make_child_objects_by_id_query(
|
||||
object_id: str,
|
||||
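As a rough worked example of the retry schedule above (assuming retry_builder multiplies the delay by `backoff` after each failure and caps it at `max_delay`): tries=5, delay=20, backoff=1.5, max_delay=60 gives waits of roughly 20s, 30s, 45s, and 60s between the five attempts, i.e. up to about 2.5 minutes of backoff per query on sustained REQUEST_LIMIT_EXCEEDED errors, on top of the 50-calls-per-60-seconds rate limit.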
@@ -99,7 +154,7 @@ class OnyxSalesforce(Salesforce):
|
||||
|
||||
queryable_fields = type_to_queryable_fields[object_type]
|
||||
query = get_object_by_id_query(object_id, object_type, queryable_fields)
|
||||
result = self.query(query)
|
||||
result = self.safe_query(query)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
@@ -151,7 +206,7 @@ class OnyxSalesforce(Salesforce):
|
||||
)
|
||||
|
||||
try:
|
||||
result = self.query(query)
|
||||
result = self.safe_query(query)
|
||||
except Exception:
|
||||
logger.exception(f"Query failed: {query=}")
|
||||
else:
|
||||
@@ -189,10 +244,25 @@ class OnyxSalesforce(Salesforce):
|
||||
|
||||
return child_records
|
||||
|
||||
@retry_builder(
|
||||
tries=3,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
exceptions=(SalesforceRefusedRequest,),
|
||||
)
|
||||
def describe_type(self, name: str) -> Any:
|
||||
sf_object = SFType(name, self.session_id, self.sf_instance)
|
||||
result = sf_object.describe()
|
||||
return result
|
||||
try:
|
||||
result = sf_object.describe()
|
||||
return result
|
||||
except SalesforceRefusedRequest as e:
|
||||
if is_salesforce_rate_limit_error(e):
|
||||
logger.warning(
|
||||
f"Salesforce rate limit exceeded for describe_type: {name}"
|
||||
)
|
||||
# Add additional delay for rate limit errors
|
||||
time.sleep(3)
|
||||
raise
|
||||
|
||||
def get_queryable_fields_by_type(self, name: str) -> list[str]:
|
||||
object_description = self.describe_type(name)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import gc
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
|
||||
@@ -7,13 +8,25 @@ from pytz import UTC
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce.bulk2 import SFBulk2Handler
|
||||
from simple_salesforce.bulk2 import SFBulk2Type
|
||||
from simple_salesforce.exceptions import SalesforceRefusedRequest
|
||||
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
|
||||
"""Check if an exception is a Salesforce rate limit error."""
|
||||
return isinstance(
|
||||
exception, SalesforceRefusedRequest
|
||||
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
|
||||
|
||||
|
||||
def _build_last_modified_time_filter_for_salesforce(
|
||||
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> str:
|
||||
@@ -71,6 +84,14 @@ def get_object_by_id_query(
|
||||
return query
|
||||
|
||||
|
||||
@retry_builder(
|
||||
tries=5,
|
||||
delay=2,
|
||||
backoff=2,
|
||||
max_delay=60,
|
||||
exceptions=(SalesforceRefusedRequest,),
|
||||
)
|
||||
@rate_limit_builder(max_calls=50, period=60)
|
||||
def _object_type_has_api_data(
|
||||
sf_client: Salesforce, sf_type: str, time_filter: str
|
||||
) -> bool:
|
||||
@@ -82,6 +103,15 @@ def _object_type_has_api_data(
|
||||
result = sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except SalesforceRefusedRequest as e:
|
||||
if is_salesforce_rate_limit_error(e):
|
||||
logger.warning(
|
||||
f"Salesforce rate limit exceeded for object type check: {sf_type}"
|
||||
)
|
||||
# Add additional delay for rate limit errors
|
||||
time.sleep(3)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
if "OPERATION_TOO_LARGE" not in str(e):
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RerankMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
@@ -70,18 +69,16 @@ def update_image_sections_with_query(
|
||||
logger.debug(
|
||||
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_record = get_default_file_store(db_session).read_file(
|
||||
cast(str, chunk.image_file_id), mode="b"
|
||||
)
|
||||
if not file_record:
|
||||
logger.error(f"Image file not found: {chunk.image_file_id}")
|
||||
raise Exception("File not found")
|
||||
file_content = file_record.read()
|
||||
image_base64 = base64.b64encode(file_content).decode()
|
||||
logger.debug(
|
||||
f"Successfully loaded image data for {chunk.image_file_id}"
|
||||
)
|
||||
|
||||
file_record = get_default_file_store().read_file(
|
||||
cast(str, chunk.image_file_id), mode="b"
|
||||
)
|
||||
if not file_record:
|
||||
logger.error(f"Image file not found: {chunk.image_file_id}")
|
||||
raise Exception("File not found")
|
||||
file_content = file_record.read()
|
||||
image_base64 = base64.b64encode(file_content).decode()
|
||||
logger.debug(f"Successfully loaded image data for {chunk.image_file_id}")
|
||||
|
||||
messages: list[BaseMessage] = [
|
||||
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),
|
||||
|
||||
@@ -229,7 +229,7 @@ def delete_messages_and_files_from_chat_session(
|
||||
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
|
||||
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
for file_info in files or []:
|
||||
file_store.delete_file(file_id=file_info.get("id"))
|
||||
|
||||
|
||||
@@ -264,6 +264,7 @@ def get_connector_credential_pair_from_id_for_user(
|
||||
def get_connector_credential_pair_from_id(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
@@ -271,6 +272,8 @@ def get_connector_credential_pair_from_id(
|
||||
|
||||
if eager_load_credential:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
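A brief usage sketch for the new eager-load flags (assuming the caller already holds a session; not lifted from this diff):

cc_pair = get_connector_credential_pair_from_id(
    db_session=db_session,
    cc_pair_id=cc_pair_id,
    eager_load_connector=True,
    eager_load_credential=True,
)
# cc_pair.connector and cc_pair.credential are loaded up front via joinedload,
# so accessing them later (e.g. from a background task) won't trigger lazy loads
# against a closed session.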
@@ -849,7 +849,9 @@ def fetch_chunk_counts_for_documents(
|
||||
# Create a dictionary of document_id to chunk_count
|
||||
chunk_counts = {str(row.id): row.chunk_count or 0 for row in results}
|
||||
|
||||
# Return a list of tuples, using 0 for documents not found in the database
|
||||
# Return a list of tuples, preserving `None` for documents not found or with
|
||||
# an unknown chunk count. Callers should handle the `None` case and fall
|
||||
# back to an existence check against the vector DB if necessary.
|
||||
return [(doc_id, chunk_counts.get(doc_id, 0)) for doc_id in document_ids]
|
||||
|
||||
|
||||
|
||||
@@ -305,6 +305,18 @@ def get_session_with_current_tenant() -> Generator[Session, None, None]:
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_current_tenant_if_none(
|
||||
session: Session | None,
|
||||
) -> Generator[Session, None, None]:
|
||||
if session is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as session:
|
||||
yield session
|
||||
else:
|
||||
yield session
|
||||
|
||||
|
||||
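A minimal sketch of how the helper above lets a DB-touching function accept an optional session (the function name is hypothetical; Session, IndexingStatus, and get_index_attempt appear elsewhere in this diff):

def get_attempt_status(
    index_attempt_id: int, db_session: Session | None = None
) -> IndexingStatus | None:
    # Reuse the caller's session if given; otherwise open one for the current tenant.
    with get_session_with_current_tenant_if_none(db_session) as session:
        attempt = get_index_attempt(session, index_attempt_id)
        return attempt.status if attempt else None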
# Used in multi tenant mode when need to refer to the shared `public` schema
|
||||
@contextmanager
|
||||
def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -43,6 +43,17 @@ def get_filerecord_by_file_id(
|
||||
return filestore
|
||||
|
||||
|
||||
def get_filerecord_by_prefix(
|
||||
prefix: str,
|
||||
db_session: Session,
|
||||
) -> list[FileRecord]:
|
||||
if not prefix:
|
||||
return db_session.query(FileRecord).all()
|
||||
return (
|
||||
db_session.query(FileRecord).filter(FileRecord.file_id.like(f"{prefix}%")).all()
|
||||
)
|
||||
|
||||
|
||||
def delete_filerecord_by_file_id(
|
||||
file_id: str,
|
||||
db_session: Session,
|
||||
|
||||
@@ -28,6 +28,8 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
# from onyx.auth.models import UserRole
|
||||
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
|
||||
@@ -95,23 +97,52 @@ def get_recent_attempts_for_cc_pair(
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
eager_load_cc_pair: bool = False,
|
||||
eager_load_search_settings: bool = False,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
|
||||
if eager_load_cc_pair:
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.connector
|
||||
)
|
||||
)
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.credential
|
||||
)
|
||||
)
|
||||
if eager_load_search_settings:
|
||||
stmt = stmt.options(joinedload(IndexAttempt.search_settings))
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
def count_error_rows_for_index_attempt(
|
||||
index_attempt_id: int,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
return (
|
||||
db_session.query(IndexAttemptError)
|
||||
.filter(IndexAttemptError.index_attempt_id == index_attempt_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
def create_index_attempt(
|
||||
connector_credential_pair_id: int,
|
||||
search_settings_id: int,
|
||||
db_session: Session,
|
||||
from_beginning: bool = False,
|
||||
celery_task_id: str | None = None,
|
||||
) -> int:
|
||||
new_attempt = IndexAttempt(
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
from_beginning=from_beginning,
|
||||
status=IndexingStatus.NOT_STARTED,
|
||||
celery_task_id=celery_task_id,
|
||||
)
|
||||
db_session.add(new_attempt)
|
||||
db_session.commit()
|
||||
@@ -247,7 +278,7 @@ def mark_attempt_in_progress(
|
||||
def mark_attempt_succeeded(
|
||||
index_attempt_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
) -> IndexAttempt:
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
@@ -256,6 +287,7 @@ def mark_attempt_succeeded(
|
||||
).scalar_one()
|
||||
|
||||
attempt.status = IndexingStatus.SUCCESS
|
||||
attempt.celery_task_id = None
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
@@ -267,6 +299,7 @@ def mark_attempt_succeeded(
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
},
|
||||
)
|
||||
return attempt
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -275,7 +308,7 @@ def mark_attempt_succeeded(
|
||||
def mark_attempt_partially_succeeded(
|
||||
index_attempt_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
) -> IndexAttempt:
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
@@ -284,6 +317,7 @@ def mark_attempt_partially_succeeded(
|
||||
).scalar_one()
|
||||
|
||||
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
attempt.celery_task_id = None
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
@@ -295,6 +329,7 @@ def mark_attempt_partially_succeeded(
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
},
|
||||
)
|
||||
return attempt
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
@@ -350,6 +385,7 @@ def mark_attempt_failed(
|
||||
attempt.status = IndexingStatus.FAILED
|
||||
attempt.error_msg = failure_reason
|
||||
attempt.full_exception_trace = full_exception_trace
|
||||
attempt.celery_task_id = None
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for index attempt status change
|
||||
@@ -373,16 +409,22 @@ def update_docs_indexed(
|
||||
new_docs_indexed: int,
|
||||
docs_removed_from_index: int,
|
||||
) -> None:
|
||||
"""Updates the docs_indexed and new_docs_indexed fields of an index attempt.
|
||||
Adds the given values to the current values in the db"""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
.where(IndexAttempt.id == index_attempt_id)
|
||||
.with_for_update()
|
||||
.with_for_update() # Locks the row when we try to update
|
||||
).scalar_one()
|
||||
|
||||
attempt.total_docs_indexed = total_docs_indexed
|
||||
attempt.new_docs_indexed = new_docs_indexed
|
||||
attempt.docs_removed_from_index = docs_removed_from_index
|
||||
attempt.total_docs_indexed = (
|
||||
attempt.total_docs_indexed or 0
|
||||
) + total_docs_indexed
|
||||
attempt.new_docs_indexed = (attempt.new_docs_indexed or 0) + new_docs_indexed
|
||||
attempt.docs_removed_from_index = (
|
||||
attempt.docs_removed_from_index or 0
|
||||
) + docs_removed_from_index
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
|
||||
backend/onyx/db/indexing_coordination.py (new file, 307 lines)
@@ -0,0 +1,307 @@
|
||||
"""Database-based indexing coordination to replace Redis fencing."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import count_error_rows_for_index_attempt
|
||||
from onyx.db.index_attempt import create_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
INDEXING_PROGRESS_TIMEOUT_HOURS = 6
|
||||
|
||||
|
||||
class CoordinationStatus(BaseModel):
|
||||
"""Status of an indexing attempt's coordination."""
|
||||
|
||||
found: bool
|
||||
total_batches: int | None
|
||||
completed_batches: int
|
||||
total_failures: int
|
||||
total_docs: int
|
||||
total_chunks: int
|
||||
status: IndexingStatus | None = None
|
||||
cancellation_requested: bool = False
|
||||
|
||||
|
||||
class IndexingCoordination:
|
||||
"""Database-based coordination for indexing tasks, replacing Redis fencing."""
|
||||
|
||||
@staticmethod
|
||||
def try_create_index_attempt(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
celery_task_id: str,
|
||||
from_beginning: bool = False,
|
||||
) -> int | None:
|
||||
"""
|
||||
Try to create a new index attempt for the given CC pair and search settings.
|
||||
Returns the index_attempt_id if successful, None if another attempt is already running.
|
||||
|
||||
This replaces the Redis fencing mechanism by using database constraints
|
||||
and transactions to prevent duplicate attempts.
|
||||
"""
|
||||
try:
|
||||
# Check for existing active attempts (this is the "fence" check)
|
||||
existing_attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
.where(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
.with_for_update(nowait=True)
|
||||
).first()
|
||||
|
||||
if existing_attempt:
|
||||
logger.info(
|
||||
f"Indexing already in progress: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"existing_attempt={existing_attempt[0].id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Create new index attempt (this is setting the "fence")
|
||||
attempt_id = create_index_attempt(
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
from_beginning=from_beginning,
|
||||
db_session=db_session,
|
||||
celery_task_id=celery_task_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created Index Attempt: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"attempt_id={attempt_id} "
|
||||
f"celery_task_id={celery_task_id}"
|
||||
)
|
||||
|
||||
return attempt_id
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.info(
|
||||
f"Failed to create index attempt (likely race condition): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def check_cancellation_requested(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if cancellation has been requested for this indexing attempt.
|
||||
This replaces Redis termination signals.
|
||||
"""
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
return attempt.cancellation_requested if attempt else False
|
||||
|
||||
@staticmethod
|
||||
def request_cancellation(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Request cancellation of an indexing attempt.
|
||||
This replaces Redis termination signals.
|
||||
"""
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if attempt:
|
||||
attempt.cancellation_requested = True
|
||||
db_session.commit()
|
||||
|
||||
logger.info(f"Requested cancellation for attempt {index_attempt_id}")
|
||||
|
||||
@staticmethod
|
||||
def set_total_batches(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
total_batches: int,
|
||||
) -> None:
|
||||
"""
|
||||
Set the total number of batches for this indexing attempt.
|
||||
Called by docfetching when extraction is complete.
|
||||
"""
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if attempt:
|
||||
attempt.total_batches = total_batches
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Set total batches: attempt={index_attempt_id} total={total_batches}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_batch_completion_and_docs(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
total_docs_indexed: int,
|
||||
new_docs_indexed: int,
|
||||
total_chunks: int,
|
||||
) -> tuple[int, int | None]:
|
||||
"""
|
||||
Update batch completion and document counts atomically.
|
||||
Returns (completed_batches, total_batches).
|
||||
This extends the existing update_docs_indexed pattern.
|
||||
"""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
.where(IndexAttempt.id == index_attempt_id)
|
||||
.with_for_update() # Same pattern as existing update_docs_indexed
|
||||
).scalar_one()
|
||||
|
||||
# Existing document count updates
|
||||
attempt.total_docs_indexed = (
|
||||
attempt.total_docs_indexed or 0
|
||||
) + total_docs_indexed
|
||||
attempt.new_docs_indexed = (
|
||||
attempt.new_docs_indexed or 0
|
||||
) + new_docs_indexed
|
||||
|
||||
# New coordination updates
|
||||
attempt.completed_batches = (attempt.completed_batches or 0) + 1
|
||||
attempt.total_chunks = (attempt.total_chunks or 0) + total_chunks
|
||||
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Updated batch completion: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"completed={attempt.completed_batches} "
|
||||
f"total={attempt.total_batches} "
|
||||
f"docs={total_docs_indexed} "
|
||||
)
|
||||
|
||||
return attempt.completed_batches, attempt.total_batches
|
||||
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to update batch completion for attempt {index_attempt_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_coordination_status(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
) -> CoordinationStatus:
|
||||
"""
|
||||
Get the current coordination status for an indexing attempt.
|
||||
This replaces reading FileStore state files.
|
||||
"""
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
return CoordinationStatus(
|
||||
found=False,
|
||||
total_batches=None,
|
||||
completed_batches=0,
|
||||
total_failures=0,
|
||||
total_docs=0,
|
||||
total_chunks=0,
|
||||
status=None,
|
||||
cancellation_requested=False,
|
||||
)
|
||||
|
||||
return CoordinationStatus(
|
||||
found=True,
|
||||
total_batches=attempt.total_batches,
|
||||
completed_batches=attempt.completed_batches,
|
||||
total_failures=count_error_rows_for_index_attempt(
|
||||
index_attempt_id, db_session
|
||||
),
|
||||
total_docs=attempt.total_docs_indexed or 0,
|
||||
total_chunks=attempt.total_chunks,
|
||||
status=attempt.status,
|
||||
cancellation_requested=attempt.cancellation_requested,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_orphaned_index_attempt_ids(db_session: Session) -> list[int]:
|
||||
"""
|
||||
Gets a list of potentially orphaned index attempts.
|
||||
These are attempts in non-terminal state that have task IDs but may have died.
|
||||
|
||||
This replaces the old get_unfenced_index_attempt_ids function.
|
||||
The actual orphan detection requires checking with Celery, which should be
|
||||
done by the caller.
|
||||
"""
|
||||
# Find attempts that are active and have task IDs
|
||||
# The caller needs to check each one with Celery to confirm orphaned status
|
||||
active_attempts = (
|
||||
db_session.execute(
|
||||
select(IndexAttempt).where(
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
IndexAttempt.celery_task_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
return [attempt.id for attempt in active_attempts]
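The caller-side confirmation is not shown here; a sketch of what it might look like, assuming access to the Celery app instance. Celery reports unknown task ids as PENDING, so a PENDING state for an attempt that claims to be running is treated as evidence (not proof) that its worker died:

from celery import Celery


def confirm_orphans(celery_app: Celery, candidates: dict[int, str]) -> list[int]:
    """candidates maps index_attempt_id -> celery_task_id."""
    orphaned = []
    for attempt_id, task_id in candidates.items():
        # Unknown/lost tasks show up as PENDING; running tasks report STARTED/RETRY etc.
        if celery_app.AsyncResult(task_id).state == "PENDING":
            orphaned.append(attempt_id)
    return orphaned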
|
||||
|
||||
@staticmethod
|
||||
def update_progress_tracking(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
current_batches_completed: int,
|
||||
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
|
||||
) -> bool:
|
||||
"""
|
||||
Update progress tracking for stall detection.
|
||||
Returns True if sufficient progress was made, False if stalled.
|
||||
"""
|
||||
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
logger.error(f"Index attempt {index_attempt_id} not found in database")
|
||||
return False
|
||||
|
||||
current_time = get_db_current_time(db_session)
|
||||
|
||||
# Check whether this is the first time progress is being tracked for this attempt
if attempt.last_progress_time is None:
# First time tracking - initialize the baseline
attempt.last_progress_time = current_time
|
||||
attempt.last_batches_completed_count = current_batches_completed
|
||||
db_session.commit()
|
||||
return True
|
||||
|
||||
time_elapsed = (current_time - attempt.last_progress_time).total_seconds()
|
||||
# only actually write to the db every timeout_hours/2 hours;
# this ensures that at most timeout_hours pass with no recorded activity
if time_elapsed < timeout_hours * 1800:
|
||||
return True
|
||||
|
||||
# Check if progress has been made
|
||||
if current_batches_completed <= attempt.last_batches_completed_count:
# at least timeout_hours/2 (and at most timeout_hours) has passed
# without any newly completed batches, so consider the attempt stalled
return False
|
||||
|
||||
# Progress made - update tracking
|
||||
attempt.last_progress_time = current_time
|
||||
attempt.last_batches_completed_count = current_batches_completed
|
||||
db_session.commit()
|
||||
return True
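The factor of 1800 is half an hour in seconds, so the row is rewritten at most once per timeout_hours/2. A worked example, assuming a 6-hour timeout (the actual INDEXING_PROGRESS_TIMEOUT_HOURS default is not shown in this diff):

# With timeout_hours = 6 (assumed value):
#   6 * 1800 == 10_800 seconds == 3 hours, i.e. half the timeout,
# so progress is persisted at most every 3 hours, and an attempt with no newly
# completed batches is declared stalled somewhere between 3 and 6 hours.
assert 6 * 1800 == 3 * 60 * 60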
|
||||
@@ -1612,9 +1612,7 @@ class SearchSettings(Base):
|
||||
|
||||
@property
|
||||
def final_embedding_dim(self) -> int:
|
||||
if self.reduced_dimension:
|
||||
return self.reduced_dimension
|
||||
return self.model_dim
|
||||
return self.reduced_dimension or self.model_dim
|
||||
|
||||
@staticmethod
|
||||
def can_use_large_chunks(
|
||||
@@ -1635,7 +1633,7 @@ class SearchSettings(Base):
|
||||
|
||||
class IndexAttempt(Base):
|
||||
"""
|
||||
Represents an attempt to index a group of 1 or more documents from a
|
||||
Represents an attempt to index a group of 0 or more documents from a
|
||||
source. For example, a single pull from Google Drive, a single event from
|
||||
slack event API, or a single website crawl.
|
||||
"""
|
||||
@@ -1683,6 +1681,30 @@ class IndexAttempt(Base):
|
||||
# can be taken to the FileStore to grab the actual checkpoint value
|
||||
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# NEW: Database-based coordination fields (replacing Redis fencing)
|
||||
celery_task_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
cancellation_requested: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# NEW: Batch coordination fields (replacing FileStore state)
|
||||
total_batches: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
completed_batches: Mapped[int] = mapped_column(Integer, default=0)
|
||||
# TODO: unused, remove this column
|
||||
total_failures_batch_level: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_chunks: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
# Progress tracking for stall detection
|
||||
last_progress_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
last_batches_completed_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
# NEW: Heartbeat tracking for worker liveness detection
|
||||
heartbeat_counter: Mapped[int] = mapped_column(Integer, default=0)
|
||||
last_heartbeat_value: Mapped[int] = mapped_column(Integer, default=0)
|
||||
last_heartbeat_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -1733,6 +1755,13 @@ class IndexAttempt(Base):
|
||||
"status",
|
||||
desc("time_updated"),
|
||||
),
|
||||
# NEW: Index for coordination queries
|
||||
Index(
|
||||
"ix_index_attempt_active_coordination",
|
||||
"connector_credential_pair_id",
|
||||
"search_settings_id",
|
||||
"status",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -1747,6 +1776,13 @@ class IndexAttempt(Base):
|
||||
def is_finished(self) -> bool:
|
||||
return self.status.is_terminal()
|
||||
|
||||
def is_coordination_complete(self) -> bool:
|
||||
"""Check if all batches have been processed"""
|
||||
return (
|
||||
self.total_batches is not None
|
||||
and self.completed_batches >= self.total_batches
|
||||
)
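The heartbeat columns above give the watchdog a DB-only liveness signal in place of the old Redis active keys. A sketch of a monitor-side check under assumed semantics (the worker bumps heartbeat_counter; the monitor copies it into last_heartbeat_value and timestamps the copy); the grace period is an arbitrary choice:

import datetime


def heartbeat_looks_dead(
    heartbeat_counter: int,
    last_heartbeat_value: int,
    last_heartbeat_time: datetime.datetime | None,
    now: datetime.datetime,
    grace: datetime.timedelta = datetime.timedelta(minutes=30),  # assumed threshold
) -> bool:
    if last_heartbeat_time is None:
        return False  # no observation yet
    # Counter unchanged since the last observation for longer than the grace
    # period -> the worker likely died without reaching a terminal status.
    return (
        heartbeat_counter == last_heartbeat_value
        and now - last_heartbeat_time > grace
    )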
|
||||
|
||||
|
||||
class IndexAttemptError(Base):
|
||||
__tablename__ = "index_attempt_errors"
|
||||
@@ -3151,7 +3187,7 @@ class PublicExternalUserGroup(Base):
|
||||
|
||||
class UsageReport(Base):
|
||||
"""This stores metadata about usage reports generated by admin including user who generated
|
||||
them as well las the period they cover. The actual zip file of the report is stored as a lo
|
||||
them as well as the period they cover. The actual zip file of the report is stored as a lo
|
||||
using the FileRecord
|
||||
"""
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ def create_user_files(
|
||||
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(files, db_session)
|
||||
upload_response = upload_files(files)
|
||||
user_files = []
|
||||
|
||||
for file_path, file in zip(upload_response.file_paths, files):
|
||||
|
||||
@@ -45,7 +45,7 @@ class IndexBatchParams:
|
||||
Information necessary for efficiently indexing a batch of documents
|
||||
"""
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str
|
||||
large_chunks_enabled: bool
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -12,7 +10,6 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def store_image_and_create_section(
|
||||
db_session: Session,
|
||||
image_data: bytes,
|
||||
file_id: str,
|
||||
display_name: str,
|
||||
@@ -24,7 +21,6 @@ def store_image_and_create_section(
|
||||
Stores an image in FileStore and creates an ImageSection object without summarization.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
image_data: Raw image bytes
|
||||
file_id: Base identifier for the file
|
||||
display_name: Human-readable name for the image
|
||||
@@ -38,7 +34,7 @@ def store_image_and_create_section(
|
||||
"""
|
||||
# Storage logic
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(image_data),
|
||||
display_name=display_name,
|
||||
|
||||
backend/onyx/file_store/document_batch_storage.py (new file, 228 lines)
@@ -0,0 +1,228 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import DocIndexingContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.file_store.file_store import FileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentBatchStorageStateType(str, Enum):
|
||||
EXTRACTION = "extraction"
|
||||
INDEXING = "indexing"
|
||||
|
||||
|
||||
DocumentStorageState: TypeAlias = DocExtractionContext | DocIndexingContext
|
||||
|
||||
STATE_TYPE_TO_MODEL: dict[str, type[DocumentStorageState]] = {
|
||||
DocumentBatchStorageStateType.EXTRACTION.value: DocExtractionContext,
|
||||
DocumentBatchStorageStateType.INDEXING.value: DocIndexingContext,
|
||||
}
|
||||
|
||||
|
||||
class BatchStoragePathInfo(BaseModel):
|
||||
cc_pair_id: int
|
||||
index_attempt_id: int
|
||||
batch_num: int
|
||||
|
||||
|
||||
class DocumentBatchStorage(ABC):
|
||||
"""Abstract base class for document batch storage implementations."""
|
||||
|
||||
def __init__(self, cc_pair_id: int, index_attempt_id: int):
|
||||
self.cc_pair_id = cc_pair_id
|
||||
self.index_attempt_id = index_attempt_id
|
||||
self.base_path = f"{self._per_cc_pair_base_path()}/{index_attempt_id}"
|
||||
|
||||
@abstractmethod
|
||||
def store_batch(self, batch_num: int, documents: List[Document]) -> None:
|
||||
"""Store a batch of documents."""
|
||||
|
||||
@abstractmethod
|
||||
def get_batch(self, batch_num: int) -> Optional[List[Document]]:
|
||||
"""Retrieve a batch of documents."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_batch_by_name(self, batch_file_name: str) -> None:
|
||||
"""Delete a specific batch."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_batch_by_num(self, batch_num: int) -> None:
|
||||
"""Delete a specific batch."""
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_all_batches(self) -> None:
|
||||
"""Clean up all batches and state for this index attempt."""
|
||||
|
||||
@abstractmethod
|
||||
def get_all_batches_for_cc_pair(self) -> list[str]:
|
||||
"""Get all IDs of batches stored in the file store."""
|
||||
|
||||
@abstractmethod
|
||||
def update_old_batches_to_new_index_attempt(self, batch_names: list[str]) -> None:
"""Update all batches to the new index attempt.

This is used when we need to re-issue docprocessing tasks for a new index attempt:
the batch file names must be rewritten to point at the new index attempt ID.
"""
|
||||
|
||||
@abstractmethod
|
||||
def extract_path_info(self, path: str) -> BatchStoragePathInfo | None:
|
||||
"""Extract path info from a path."""
|
||||
|
||||
def _serialize_documents(self, documents: list[Document]) -> str:
|
||||
"""Serialize documents to JSON string."""
|
||||
# Use mode='json' to properly serialize datetime and other complex types
|
||||
return json.dumps([doc.model_dump(mode="json") for doc in documents], indent=2)
|
||||
|
||||
def _deserialize_documents(self, data: str) -> list[Document]:
|
||||
"""Deserialize documents from JSON string."""
|
||||
doc_dicts = json.loads(data)
|
||||
return [Document.model_validate(doc_dict) for doc_dict in doc_dicts]
|
||||
|
||||
def _per_cc_pair_base_path(self) -> str:
|
||||
"""Get the base path for the cc pair."""
|
||||
return f"iab/{self.cc_pair_id}"
|
||||
|
||||
|
||||
class FileStoreDocumentBatchStorage(DocumentBatchStorage):
|
||||
"""FileStore-based implementation of document batch storage."""
|
||||
|
||||
def __init__(self, cc_pair_id: int, index_attempt_id: int, file_store: FileStore):
|
||||
super().__init__(cc_pair_id, index_attempt_id)
|
||||
self.file_store = file_store
|
||||
|
||||
def _get_batch_file_name(self, batch_num: int) -> str:
|
||||
"""Generate file name for a document batch."""
|
||||
return f"{self.base_path}/{batch_num}.json"
|
||||
|
||||
def store_batch(self, batch_num: int, documents: list[Document]) -> None:
|
||||
"""Store a batch of documents using FileStore."""
|
||||
file_name = self._get_batch_file_name(batch_num)
|
||||
try:
|
||||
data = self._serialize_documents(documents)
|
||||
content = StringIO(data)
|
||||
|
||||
self.file_store.save_file(
|
||||
file_id=file_name,
|
||||
content=content,
|
||||
display_name=f"Document Batch {batch_num}",
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="application/json",
|
||||
file_metadata={
|
||||
"batch_num": batch_num,
|
||||
"document_count": str(len(documents)),
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Stored batch {batch_num} with {len(documents)} documents to FileStore as {file_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store batch {batch_num}: {e}")
|
||||
raise
|
||||
|
||||
def get_batch(self, batch_num: int) -> list[Document] | None:
|
||||
"""Retrieve a batch of documents from FileStore."""
|
||||
file_name = self._get_batch_file_name(batch_num)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not self.file_store.has_file(
|
||||
file_id=file_name,
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="application/json",
|
||||
):
|
||||
logger.warning(
|
||||
f"Batch {batch_num} not found in FileStore with name {file_name}"
|
||||
)
|
||||
return None
|
||||
|
||||
content_io = self.file_store.read_file(file_name)
|
||||
data = content_io.read().decode("utf-8")
|
||||
|
||||
documents = self._deserialize_documents(data)
|
||||
logger.debug(
|
||||
f"Retrieved batch {batch_num} with {len(documents)} documents from FileStore"
|
||||
)
|
||||
return documents
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve batch {batch_num}: {e}")
|
||||
raise
|
||||
|
||||
def delete_batch_by_name(self, batch_file_name: str) -> None:
|
||||
"""Delete a specific batch from FileStore."""
|
||||
self.file_store.delete_file(batch_file_name)
|
||||
logger.debug(f"Deleted batch {batch_file_name} from FileStore")
|
||||
|
||||
def delete_batch_by_num(self, batch_num: int) -> None:
|
||||
"""Delete a specific batch from FileStore."""
|
||||
batch_file_name = self._get_batch_file_name(batch_num)
|
||||
self.delete_batch_by_name(batch_file_name)
|
||||
logger.debug(f"Deleted batch num {batch_num} {batch_file_name} from FileStore")
|
||||
|
||||
def cleanup_all_batches(self) -> None:
|
||||
"""Clean up all batches for this index attempt."""
|
||||
for batch_file_name in self.get_all_batches_for_cc_pair():
|
||||
self.delete_batch_by_name(batch_file_name)
|
||||
|
||||
def get_all_batches_for_cc_pair(self) -> list[str]:
|
||||
"""Get all IDs of batches stored in the file store for the cc pair
|
||||
this batch store was initialized with.
|
||||
This includes any batches left over from a previous
|
||||
indexing attempt that need to be processed.
|
||||
"""
|
||||
return [
|
||||
file.file_id
|
||||
for file in self.file_store.list_files_by_prefix(
|
||||
self._per_cc_pair_base_path()
|
||||
)
|
||||
]
|
||||
|
||||
def update_old_batches_to_new_index_attempt(self, batch_names: list[str]) -> None:
|
||||
"""Update all batches to the new index attempt."""
|
||||
for batch_file_name in batch_names:
|
||||
path_info = self.extract_path_info(batch_file_name)
|
||||
if path_info is None:
|
||||
continue
|
||||
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
|
||||
self.file_store.change_file_id(batch_file_name, new_batch_file_name)
|
||||
|
||||
def extract_path_info(self, path: str) -> BatchStoragePathInfo | None:
|
||||
"""Extract path info from a path."""
|
||||
path_spl = path.split("/")
|
||||
# TODO: remove this in a few months, just for backwards compatibility
|
||||
if len(path_spl) == 3:
|
||||
path_spl = ["iab"] + path_spl
|
||||
try:
|
||||
_, cc_pair_id, index_attempt_id, batch_num = path_spl
|
||||
return BatchStoragePathInfo(
|
||||
cc_pair_id=int(cc_pair_id),
|
||||
index_attempt_id=int(index_attempt_id),
|
||||
batch_num=int(batch_num.split(".")[0]), # remove .json
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract path info from {path}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_document_batch_storage(
|
||||
cc_pair_id: int, index_attempt_id: int
|
||||
) -> DocumentBatchStorage:
|
||||
"""Factory function to get the configured document batch storage implementation."""
|
||||
# The get_default_file_store will now correctly use S3BackedFileStore
|
||||
# or other configured stores based on environment variables
|
||||
file_store = get_default_file_store()
|
||||
return FileStoreDocumentBatchStorage(cc_pair_id, index_attempt_id, file_store)
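A usage sketch of the new storage module: docfetching writes batches, docprocessing later reads and deletes them. The ids are made up, and running this requires a configured file store backend (MinIO/S3) plus a database for the file records:

from onyx.connectors.models import Document
from onyx.file_store.document_batch_storage import get_document_batch_storage

storage = get_document_batch_storage(cc_pair_id=12, index_attempt_id=345)

docs: list[Document] = []  # produced by the connector in the real flow
storage.store_batch(1, docs)  # docfetching side

batch = storage.get_batch(1)  # docprocessing side
if batch is not None:
    ...  # run the indexing pipeline on the batch
    storage.delete_batch_by_num(1)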
|
||||
@@ -22,10 +22,14 @@ from onyx.configs.app_configs import S3_FILE_STORE_BUCKET_NAME
|
||||
from onyx.configs.app_configs import S3_FILE_STORE_PREFIX
|
||||
from onyx.configs.app_configs import S3_VERIFY_SSL
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.file_record import delete_filerecord_by_file_id
|
||||
from onyx.db.file_record import get_filerecord_by_file_id
|
||||
from onyx.db.file_record import get_filerecord_by_file_id_optional
|
||||
from onyx.db.file_record import get_filerecord_by_prefix
|
||||
from onyx.db.file_record import upsert_filerecord
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import FileRecord as FileStoreModel
|
||||
from onyx.file_store.s3_key_utils import generate_s3_key
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
@@ -129,13 +133,29 @@ class FileStore(ABC):
|
||||
Get the file + parse out the mime type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def change_file_id(self, old_file_id: str, new_file_id: str) -> None:
|
||||
"""
|
||||
Change the file ID of an existing file.
|
||||
|
||||
Parameters:
|
||||
- old_file_id: Current file ID
|
||||
- new_file_id: New file ID to assign
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
|
||||
"""
|
||||
List all file IDs that start with the given prefix.
|
||||
"""
|
||||
|
||||
|
||||
class S3BackedFileStore(FileStore):
|
||||
"""Isn't necessarily S3, but is any S3-compatible storage (e.g. MinIO)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_session: Session,
|
||||
bucket_name: str,
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
@@ -144,7 +164,6 @@ class S3BackedFileStore(FileStore):
|
||||
s3_prefix: str | None = None,
|
||||
s3_verify_ssl: bool = True,
|
||||
) -> None:
|
||||
self.db_session = db_session
|
||||
self._s3_client: S3Client | None = None
|
||||
self._bucket_name = bucket_name
|
||||
self._aws_access_key_id = aws_access_key_id
|
||||
@@ -272,10 +291,12 @@ class S3BackedFileStore(FileStore):
|
||||
file_id: str,
|
||||
file_origin: FileOrigin,
|
||||
file_type: str,
|
||||
db_session: Session | None = None,
|
||||
) -> bool:
|
||||
file_record = get_filerecord_by_file_id_optional(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
file_record = get_filerecord_by_file_id_optional(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
return (
|
||||
file_record is not None
|
||||
and file_record.file_origin == file_origin
|
||||
@@ -290,6 +311,7 @@ class S3BackedFileStore(FileStore):
|
||||
file_type: str,
|
||||
file_metadata: dict[str, Any] | None = None,
|
||||
file_id: str | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> str:
|
||||
if file_id is None:
|
||||
file_id = str(uuid.uuid4())
|
||||
@@ -314,27 +336,33 @@ class S3BackedFileStore(FileStore):
|
||||
ContentType=file_type,
|
||||
)
|
||||
|
||||
# Save metadata to database
|
||||
upsert_filerecord(
|
||||
file_id=file_id,
|
||||
display_name=display_name or file_id,
|
||||
file_origin=file_origin,
|
||||
file_type=file_type,
|
||||
bucket_name=bucket_name,
|
||||
object_key=s3_key,
|
||||
db_session=self.db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
self.db_session.commit()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
# Save metadata to database
|
||||
upsert_filerecord(
|
||||
file_id=file_id,
|
||||
display_name=display_name or file_id,
|
||||
file_origin=file_origin,
|
||||
file_type=file_type,
|
||||
bucket_name=bucket_name,
|
||||
object_key=s3_key,
|
||||
db_session=db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return file_id
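get_session_with_current_tenant_if_none is what lets these methods drop the constructor-held session: it reuses a caller-provided session or opens a fresh tenant-scoped one. Its real implementation lives in onyx.db.engine.sql_engine and is not part of this diff; a presumed-equivalent sketch for orientation only:

from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy.orm import Session

from onyx.db.engine.sql_engine import get_session_with_current_tenant


@contextmanager
def session_if_none_sketch(db_session: Session | None) -> Iterator[Session]:
    if db_session is not None:
        yield db_session  # the caller keeps ownership of commit/rollback
        return
    with get_session_with_current_tenant() as new_session:
        yield new_session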
|
||||
|
||||
def read_file(
|
||||
self, file_id: str, mode: str | None = None, use_tempfile: bool = False
|
||||
self,
|
||||
file_id: str,
|
||||
mode: str | None = None,
|
||||
use_tempfile: bool = False,
|
||||
db_session: Session | None = None,
|
||||
) -> IO[bytes]:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
|
||||
s3_client = self._get_s3_client()
|
||||
try:
|
||||
@@ -356,32 +384,107 @@ class S3BackedFileStore(FileStore):
|
||||
else:
|
||||
return BytesIO(file_content)
|
||||
|
||||
def read_file_record(self, file_id: str) -> FileStoreModel:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
def read_file_record(
|
||||
self, file_id: str, db_session: Session | None = None
|
||||
) -> FileStoreModel:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
return file_record
|
||||
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
try:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
try:
|
||||
|
||||
# Delete from external storage
|
||||
s3_client = self._get_s3_client()
|
||||
s3_client.delete_object(
|
||||
Bucket=file_record.bucket_name, Key=file_record.object_key
|
||||
)
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
if not file_record.bucket_name:
|
||||
logger.error(
|
||||
f"File record {file_id} with key {file_record.object_key} "
|
||||
"has no bucket name, cannot delete from filestore"
|
||||
)
|
||||
delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
# Delete metadata from database
|
||||
delete_filerecord_by_file_id(file_id=file_id, db_session=self.db_session)
|
||||
# Delete from external storage
|
||||
s3_client = self._get_s3_client()
|
||||
s3_client.delete_object(
|
||||
Bucket=file_record.bucket_name, Key=file_record.object_key
|
||||
)
|
||||
|
||||
self.db_session.commit()
|
||||
# Delete metadata from database
|
||||
delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
|
||||
|
||||
except Exception:
|
||||
self.db_session.rollback()
|
||||
raise
|
||||
db_session.commit()
|
||||
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
def change_file_id(
|
||||
self, old_file_id: str, new_file_id: str, db_session: Session | None = None
|
||||
) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
try:
|
||||
# Get the existing file record
|
||||
old_file_record = get_filerecord_by_file_id(
|
||||
file_id=old_file_id, db_session=db_session
|
||||
)
|
||||
|
||||
# Generate new S3 key for the new file ID
|
||||
new_s3_key = self._get_s3_key(new_file_id)
|
||||
|
||||
# Copy S3 object to new key
|
||||
s3_client = self._get_s3_client()
|
||||
bucket_name = self._get_bucket_name()
|
||||
|
||||
copy_source = (
|
||||
f"{old_file_record.bucket_name}/{old_file_record.object_key}"
|
||||
)
|
||||
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=new_s3_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Create new file record with new file_id
|
||||
# Cast file_metadata to the expected type
|
||||
file_metadata = cast(
|
||||
dict[Any, Any] | None, old_file_record.file_metadata
|
||||
)
|
||||
|
||||
upsert_filerecord(
|
||||
file_id=new_file_id,
|
||||
display_name=old_file_record.display_name,
|
||||
file_origin=old_file_record.file_origin,
|
||||
file_type=old_file_record.file_type,
|
||||
bucket_name=bucket_name,
|
||||
object_key=new_s3_key,
|
||||
db_session=db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
|
||||
# Delete old S3 object
|
||||
s3_client.delete_object(
|
||||
Bucket=old_file_record.bucket_name, Key=old_file_record.object_key
|
||||
)
|
||||
|
||||
# Delete old file record
|
||||
delete_filerecord_by_file_id(file_id=old_file_id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to change file ID from {old_file_id} to {new_file_id}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
|
||||
mime_type: str = "application/octet-stream"
|
||||
@@ -395,8 +498,18 @@ class S3BackedFileStore(FileStore):
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
|
||||
"""
|
||||
List all file IDs that start with the given prefix.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_records = get_filerecord_by_prefix(
|
||||
prefix=prefix, db_session=db_session
|
||||
)
|
||||
return file_records
|
||||
|
||||
def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
|
||||
|
||||
def get_s3_file_store() -> S3BackedFileStore:
|
||||
"""
|
||||
Returns the S3 file store implementation.
|
||||
"""
|
||||
@@ -409,7 +522,6 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
|
||||
)
|
||||
|
||||
return S3BackedFileStore(
|
||||
db_session=db_session,
|
||||
bucket_name=bucket_name,
|
||||
aws_access_key_id=S3_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY,
|
||||
@@ -420,7 +532,7 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
|
||||
)
|
||||
|
||||
|
||||
def get_default_file_store(db_session: Session) -> FileStore:
|
||||
def get_default_file_store() -> FileStore:
|
||||
"""
|
||||
Returns the configured file store implementation.
|
||||
|
||||
@@ -445,4 +557,4 @@ def get_default_file_store(db_session: Session) -> FileStore:
|
||||
Other S3-compatible storage (Digital Ocean, Linode, etc.):
|
||||
- Same as MinIO, but set appropriate S3_ENDPOINT_URL
|
||||
"""
|
||||
return get_s3_file_store(db_session)
|
||||
return get_s3_file_store()
|
||||
|
||||
@@ -8,7 +8,6 @@ import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
@@ -49,28 +48,23 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
|
||||
|
||||
# Use a separate session to avoid committing the caller's transaction
|
||||
try:
|
||||
with get_session_with_current_tenant() as file_store_session:
|
||||
file_store = get_default_file_store(file_store_session)
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
) -> InMemoryChatFile:
|
||||
file_io = get_default_file_store(db_session).read_file(
|
||||
file_descriptor["id"], mode="b"
|
||||
)
|
||||
def load_chat_file(file_descriptor: FileDescriptor) -> InMemoryChatFile:
|
||||
file_io = get_default_file_store().read_file(file_descriptor["id"], mode="b")
|
||||
return InMemoryChatFile(
|
||||
file_id=file_descriptor["id"],
|
||||
content=file_io.read(),
|
||||
@@ -82,7 +76,6 @@ def load_chat_file(
|
||||
def load_all_chat_files(
|
||||
chat_messages: list[ChatMessage],
|
||||
file_descriptors: list[FileDescriptor],
|
||||
db_session: Session,
|
||||
) -> list[InMemoryChatFile]:
|
||||
file_descriptors_for_history: list[FileDescriptor] = []
|
||||
for chat_message in chat_messages:
|
||||
@@ -93,7 +86,7 @@ def load_all_chat_files(
|
||||
list[InMemoryChatFile],
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(load_chat_file, (file, db_session))
|
||||
(load_chat_file, (file,))
|
||||
for file in file_descriptors + file_descriptors_for_history
|
||||
]
|
||||
),
|
||||
@@ -117,7 +110,7 @@ def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
|
||||
raise ValueError(f"User file with id {file_id} not found")
|
||||
|
||||
# Get the file record to determine the appropriate chat file type
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_record = file_store.read_file_record(user_file.file_id)
|
||||
|
||||
# Determine appropriate chat file type based on the original file's MIME type
|
||||
@@ -263,34 +256,29 @@ def get_user_files_as_user(
|
||||
|
||||
|
||||
def save_file_from_url(url: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
file_io = BytesIO(response.content)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_id = file_store.save_file(
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png;base64",
|
||||
)
|
||||
return file_id
|
||||
file_io = BytesIO(response.content)
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png;base64",
|
||||
)
|
||||
return file_id
|
||||
|
||||
|
||||
def save_file_from_base64(base64_string: str) -> str:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(base64.b64decode(base64_string)),
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type=get_image_type(base64_string),
|
||||
)
|
||||
return file_id
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(base64.b64decode(base64_string)),
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type=get_image_type(base64_string),
|
||||
)
|
||||
return file_id
|
||||
|
||||
|
||||
def save_file(
|
||||
|
||||
@@ -4,6 +4,13 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
|
||||
def make_default_kwargs() -> dict[str, Any]:
|
||||
return {
|
||||
"http2": True,
|
||||
"limits": httpx.Limits(),
|
||||
}
|
||||
|
||||
|
||||
class HttpxPool:
|
||||
"""Class to manage a global httpx Client instance"""
|
||||
|
||||
@@ -11,10 +18,6 @@ class HttpxPool:
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# Default parameters for creation
|
||||
DEFAULT_KWARGS = {
|
||||
"http2": True,
|
||||
"limits": lambda: httpx.Limits(),
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
@@ -22,7 +25,7 @@ class HttpxPool:
|
||||
@classmethod
|
||||
def _init_client(cls, **kwargs: Any) -> httpx.Client:
|
||||
"""Private helper method to create and return an httpx.Client."""
|
||||
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
|
||||
merged_kwargs = {**(make_default_kwargs()), **kwargs}
|
||||
return httpx.Client(**merged_kwargs)
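The old class-level DEFAULT_KWARGS stored "limits" as a lambda, so the merge could hand httpx.Client a callable instead of an httpx.Limits instance; building the kwargs per call also avoids sharing one dict across clients. A small check of the merge behaviour (the timeout override is just an example value, and http2=True needs the h2 extra installed):

import httpx

# Fresh defaults merged with a caller override, mirroring _init_client above.
merged_kwargs = {**{"http2": True, "limits": httpx.Limits()}, **{"timeout": 30.0}}
client = httpx.Client(**merged_kwargs)
client.close()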
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -4,6 +4,7 @@ from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -261,6 +262,11 @@ def embed_chunks_with_failure_handling(
|
||||
),
|
||||
[],
|
||||
)
|
||||
except ConnectorStopSignal as e:
|
||||
logger.warning(
|
||||
"Connector stop signal detected in embed_chunks_with_failure_handling"
|
||||
)
|
||||
raise e
|
||||
except Exception:
|
||||
logger.exception("Failed to embed chunk batch. Trying individual docs.")
|
||||
# wait a couple seconds to let any rate limits or temporary issues resolve
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -23,6 +22,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
@@ -41,7 +41,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -62,7 +61,6 @@ from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
@@ -290,6 +288,10 @@ def index_doc_batch_with_handler(
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
except ConnectorStopSignal as e:
|
||||
logger.warning("Connector stop signal detected in index_doc_batch_with_handler")
|
||||
raise e
|
||||
except Exception as e:
|
||||
# don't log the batch directly, it's too much text
|
||||
document_ids = [doc.id for doc in document_batch]
|
||||
@@ -496,36 +498,33 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
|
||||
# Try to get image summary
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
file_record = file_store.read_file_record(
|
||||
file_record = file_store.read_file_record(
|
||||
file_id=section.image_file_id
|
||||
)
|
||||
if not file_record:
|
||||
logger.warning(
|
||||
f"Image file {section.image_file_id} not found in FileStore"
|
||||
)
|
||||
|
||||
processed_section.text = "[Image could not be processed]"
|
||||
else:
|
||||
# Get the image data
|
||||
image_data_io = file_store.read_file(
|
||||
file_id=section.image_file_id
|
||||
)
|
||||
if not file_record:
|
||||
logger.warning(
|
||||
f"Image file {section.image_file_id} not found in FileStore"
|
||||
)
|
||||
image_data = image_data_io.read()
|
||||
summary = summarize_image_with_error_handling(
|
||||
llm=llm,
|
||||
image_data=image_data,
|
||||
context_name=file_record.display_name or "Image",
|
||||
)
|
||||
|
||||
processed_section.text = "[Image could not be processed]"
|
||||
if summary:
|
||||
processed_section.text = summary
|
||||
else:
|
||||
# Get the image data
|
||||
image_data_io = file_store.read_file(
|
||||
file_id=section.image_file_id
|
||||
)
|
||||
image_data = image_data_io.read()
|
||||
summary = summarize_image_with_error_handling(
|
||||
llm=llm,
|
||||
image_data=image_data,
|
||||
context_name=file_record.display_name or "Image",
|
||||
)
|
||||
|
||||
if summary:
|
||||
processed_section.text = summary
|
||||
else:
|
||||
processed_section.text = (
|
||||
"[Image could not be summarized]"
|
||||
)
|
||||
processed_section.text = "[Image could not be summarized]"
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image section: {e}")
|
||||
processed_section.text = "[Error processing image]"
|
||||
@@ -832,7 +831,7 @@ def index_doc_batch(
|
||||
)
|
||||
)
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None] = {
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
@@ -1029,7 +1028,7 @@ def index_doc_batch(
|
||||
db_session.commit()
|
||||
|
||||
result = IndexingPipelineResult(
|
||||
new_docs=len([r for r in insertion_records if r.already_existed is False]),
|
||||
new_docs=len([r for r in insertion_records if not r.already_existed]),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(access_aware_chunks),
|
||||
failures=vector_db_write_failures + embedding_failures,
|
||||
@@ -1038,8 +1037,10 @@ def index_doc_batch(
|
||||
return result
|
||||
|
||||
|
||||
def build_indexing_pipeline(
|
||||
def run_indexing_pipeline(
|
||||
*,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
@@ -1047,8 +1048,7 @@ def build_indexing_pipeline(
|
||||
tenant_id: str,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
) -> IndexingPipelineResult:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
all_search_settings = get_active_search_settings(db_session)
|
||||
if (
|
||||
@@ -1078,15 +1078,15 @@ def build_indexing_pipeline(
|
||||
enable_large_chunks=multipass_config.enable_large_chunks,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
# after every doc, update status in case there are a bunch of really long docs
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
return partial(
|
||||
index_doc_batch_with_handler,
|
||||
return index_doc_batch_with_handler(
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
document_batch=document_batch,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -259,7 +259,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
# set up the file store (e.g. create bucket if needed). On multi-tenant,
|
||||
# this is done via IaC
|
||||
get_default_file_store(db_session).initialize()
|
||||
get_default_file_store().initialize()
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.configs.model_configs import (
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
|
||||
)
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.natural_language_processing.exceptions import (
|
||||
@@ -198,7 +199,9 @@ class EmbeddingModel:
|
||||
) -> tuple[int, list[Embedding]]:
|
||||
if self.callback:
|
||||
if self.callback.should_stop():
|
||||
raise RuntimeError("_batch_encode_texts detected stop signal")
|
||||
raise ConnectorStopSignal(
|
||||
"_batch_encode_texts detected stop signal"
|
||||
)
|
||||
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import time
|
||||
|
||||
import redis
|
||||
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
@@ -12,68 +9,33 @@ from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
# TODO: reduce dependence on redis
|
||||
class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks / associated redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, cc_pair_id: int) -> None:
|
||||
"""id: a connector credential pair id"""
|
||||
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.cc_pair_id: int = cc_pair_id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
|
||||
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
|
||||
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
|
||||
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
|
||||
self.stop = RedisConnectorStop(tenant_id, cc_pair_id, self.redis)
|
||||
self.prune = RedisConnectorPrune(tenant_id, cc_pair_id, self.redis)
|
||||
self.delete = RedisConnectorDelete(tenant_id, cc_pair_id, self.redis)
|
||||
self.permissions = RedisConnectorPermissionSync(
|
||||
tenant_id, cc_pair_id, self.redis
|
||||
)
|
||||
self.external_group_sync = RedisConnectorExternalGroupSync(
|
||||
tenant_id, id, self.redis
|
||||
tenant_id, cc_pair_id, self.redis
|
||||
)
|
||||
|
||||
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
|
||||
return RedisConnectorIndex(
|
||||
self.tenant_id, self.id, search_settings_id, self.redis
|
||||
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
|
||||
)
|
||||
|
||||
def wait_for_indexing_termination(
|
||||
self,
|
||||
search_settings_list: list[SearchSettings],
|
||||
timeout: float = 15.0,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if all indexing for the given redis connector is finished within the given timeout.
|
||||
Returns False if the timeout is exceeded
|
||||
|
||||
This check does not guarantee that current indexings being terminated
|
||||
won't get restarted midflight
|
||||
"""
|
||||
|
||||
finished = False
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
while True:
|
||||
still_indexing = False
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = self.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
still_indexing = True
|
||||
break
|
||||
|
||||
if not still_indexing:
|
||||
finished = True
|
||||
break
|
||||
|
||||
now = time.monotonic()
|
||||
if now - start > timeout:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
return finished
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
|
||||
@@ -58,10 +58,7 @@ class RedisConnectorDelete:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorDeletePayload | None:
|
||||
@@ -93,10 +90,7 @@ class RedisConnectorDelete:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
|
||||
@@ -88,10 +88,7 @@ class RedisConnectorPermissionSync:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
|
||||
@@ -128,10 +125,7 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
|
||||
@@ -84,10 +84,7 @@ class RedisConnectorExternalGroupSync:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
|
||||
@@ -125,10 +122,7 @@ class RedisConnectorExternalGroupSync:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
|
||||
|
||||
class RedisConnectorIndexPayload(BaseModel):
|
||||
@@ -31,7 +28,10 @@ class RedisConnectorIndex:
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorindexing_generator_complete
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
|
||||
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
|
||||
DB_LOCK_PREFIX = "da_lock:indexing:db"
|
||||
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
|
||||
|
||||
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
|
||||
TERMINATE_TTL = 600
|
||||
@@ -53,130 +53,34 @@ class RedisConnectorIndex:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.cc_pair_id = cc_pair_id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.generator_progress_key = (
|
||||
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_complete_key = (
|
||||
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
|
||||
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.filestore_lock_key = (
|
||||
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_lock_key = (
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
|
||||
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.connector_active_key = (
|
||||
f"{self.CONNECTOR_ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.per_worker_lock_key = (
|
||||
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
self.terminate_key = (
|
||||
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
|
||||
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
|
||||
def generate_generator_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorIndexPayload | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(Any, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
payload: RedisConnectorIndexPayload | None,
|
||||
) -> None:
|
||||
if not payload:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def terminating(self, celery_task_id: str) -> bool:
|
||||
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
|
||||
|
||||
def set_terminate(self, celery_task_id: str) -> None:
|
||||
"""This sets a signal. It does not block!"""
|
||||
# We shouldn't need very long to terminate the spawned task.
|
||||
# 10 minute TTL is good.
|
||||
self.redis.set(
|
||||
f"{self.terminate_key}_{celery_task_id}", 0, ex=self.TERMINATE_TTL
|
||||
)
|
||||
|
||||
def set_watchdog(self, value: bool) -> None:
|
||||
"""Signal the state of the watchdog."""
|
||||
if not value:
|
||||
self.redis.delete(self.watchdog_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.watchdog_key, 0, ex=self.WATCHDOG_TTL)
|
||||
|
||||
def watchdog_signaled(self) -> bool:
|
||||
"""Check the state of the watchdog."""
|
||||
return bool(self.redis.exists(self.watchdog_key))
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
def set_connector_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.connector_active_key, 0, ex=self.CONNECTOR_ACTIVE_TTL)
|
||||
|
||||
def connector_active(self) -> bool:
|
||||
if self.redis.exists(self.connector_active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def connector_active_ttl(self) -> int:
|
||||
"""Refer to https://redis.io/docs/latest/commands/ttl/
|
||||
|
||||
-2 means the key does not exist
|
||||
-1 means the key exists but has no associated expire
|
||||
Otherwise, returns the actual TTL of the key
|
||||
"""
|
||||
ttl = cast(int, self.redis.ttl(self.connector_active_key))
|
||||
return ttl
|
||||
|
||||
def generator_locked(self) -> bool:
|
||||
return bool(self.redis.exists(self.generator_lock_key))
|
||||
def lock_key_by_batch(self, batch_n: int) -> str:
|
||||
return f"{self.per_worker_lock_key}/{batch_n}"
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
@@ -186,21 +90,9 @@ class RedisConnectorIndex:
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_progress(self) -> int | None:
|
||||
"""Returns None if the key doesn't exist. The"""
|
||||
# TODO: move into fence?
|
||||
bytes = self.redis.get(self.generator_progress_key)
|
||||
if bytes is None:
|
||||
return None
|
||||
|
||||
progress = int(cast(int, bytes))
|
||||
return progress
|
||||
|
||||
def get_completion(self) -> int | None:
|
||||
# TODO: move into fence?
|
||||
bytes = self.redis.get(self.generator_complete_key)
|
||||
if bytes is None:
|
||||
return None
|
||||
@@ -209,24 +101,22 @@ class RedisConnectorIndex:
|
||||
return status
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.connector_active_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.filestore_lock_key)
|
||||
self.redis.delete(self.db_lock_key)
|
||||
self.redis.delete(self.generator_lock_key)
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
# leaving these temporarily for backwards compat, TODO: remove
|
||||
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
|
||||
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
|
||||
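These helpers are thin wrappers over short-lived Redis keys: setting a key with a TTL raises a signal, key existence answers the query, and expiry clears it automatically. A condensed sketch of the same pattern outside the class, using plain redis-py (the key names and TTL constant here are illustrative, not the class attributes):

import redis

r = redis.Redis()

TERMINATE_TTL = 600  # 10 minutes, matching the comment in set_terminate above


def set_terminate(terminate_key: str, celery_task_id: str) -> None:
    # Raise the signal; it disappears on its own after the TTL, so no cleanup is required.
    r.set(f"{terminate_key}_{celery_task_id}", 0, ex=TERMINATE_TTL)


def terminating(terminate_key: str, celery_task_id: str) -> bool:
    # Watchdog threads poll this; nothing blocks on the signal.
    return bool(r.exists(f"{terminate_key}_{celery_task_id}"))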
@@ -92,10 +92,7 @@ class RedisConnectorPrune:

    @property
    def fenced(self) -> bool:
        if self.redis.exists(self.fence_key):
            return True

        return False
        return bool(self.redis.exists(self.fence_key))

    @property
    def payload(self) -> RedisConnectorPrunePayload | None:

@@ -130,10 +127,7 @@ class RedisConnectorPrune:
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False
        return bool(self.redis.exists(self.active_key))

    @property
    def generator_complete(self) -> int | None:

@@ -23,10 +23,7 @@ class RedisConnectorStop:

    @property
    def fenced(self) -> bool:
        if self.redis.exists(self.fence_key):
            return True

        return False
        return bool(self.redis.exists(self.fence_key))

    def set_fence(self, value: bool) -> None:
        if not value:

@@ -37,10 +34,7 @@ class RedisConnectorStop:

    @property
    def timed_out(self) -> bool:
        if self.redis.exists(self.timeout_key):
            return False

        return True
        return not bool(self.redis.exists(self.timeout_key))

    def set_timeout(self) -> None:
        """After calling this, call timed_out to determine if the timeout has been
@@ -6,6 +6,7 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

@@ -34,15 +35,15 @@ from onyx.db.document import get_documents_for_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import count_index_attempts_for_connector
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.models import SearchSettings
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot
from onyx.redis.redis_pool import get_redis_client
@@ -139,11 +140,6 @@ def get_cc_pair_full_info(
        only_finished=False,
    )

    search_settings = get_current_search_settings(db_session)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings.id)

    return CCPairFullInfo.from_models(
        cc_pair_model=cc_pair,
        number_of_index_attempts=count_index_attempts_for_connector(
@@ -159,7 +155,9 @@ def get_cc_pair_full_info(
        ),
        num_docs_indexed=documents_indexed,
        is_editable_for_current_user=is_editable_for_current_user,
        indexing=redis_connector_index.fenced,
        indexing=bool(
            latest_attempt and latest_attempt.status == IndexingStatus.IN_PROGRESS
        ),
    )

@@ -195,31 +193,35 @@ def update_cc_pair_status(
    if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
        redis_connector.stop.set_fence(True)

        search_settings_list: list[SearchSettings] = get_active_search_settings_list(
            db_session
        # Request cancellation for any active indexing attempts for this cc_pair
        active_attempts = (
            db_session.execute(
                select(IndexAttempt).where(
                    IndexAttempt.connector_credential_pair_id == cc_pair_id,
                    IndexAttempt.status.in_(
                        [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
                    ),
                )
            )
            .scalars()
            .all()
        )

        while True:
            for search_settings in search_settings_list:
                redis_connector_index = redis_connector.new_index(search_settings.id)
                if not redis_connector_index.fenced:
                    continue

                index_payload = redis_connector_index.payload
                if not index_payload:
                    continue

                if not index_payload.celery_task_id:
                    continue

        for attempt in active_attempts:
            try:
                IndexingCoordination.request_cancellation(db_session, attempt.id)
                # Revoke the task to prevent it from running
                client_app.control.revoke(index_payload.celery_task_id)
                if attempt.celery_task_id:
                    client_app.control.revoke(attempt.celery_task_id)
                logger.info(
                    f"Requested cancellation for active indexing attempt {attempt.id} "
                    f"due to connector pause: cc_pair={cc_pair_id}"
                )
            except Exception:
                logger.exception(
                    f"Failed to request cancellation for indexing attempt {attempt.id}"
                )

                # If it is running, then signaling for termination will get the
                # watchdog thread to kill the spawned task
                redis_connector_index.set_terminate(index_payload.celery_task_id)

            break
    else:
        redis_connector.stop.set_fence(False)

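For readability, here is the new pause-time cancellation path condensed into a standalone helper. This is only a sketch of what the endpoint above does, not a separate utility in the codebase; the `client_app` parameter stands in for the Celery app the route already uses.

from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.enums import IndexingStatus
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt


def cancel_active_index_attempts(db_session: Session, cc_pair_id: int, client_app) -> None:
    """Flag NOT_STARTED / IN_PROGRESS attempts for cancellation and revoke their tasks."""
    active_attempts = (
        db_session.execute(
            select(IndexAttempt).where(
                IndexAttempt.connector_credential_pair_id == cc_pair_id,
                IndexAttempt.status.in_(
                    [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
                ),
            )
        )
        .scalars()
        .all()
    )
    for attempt in active_attempts:
        # DB-level cancellation signal; the decoupled workers observe it through
        # IndexingCoordination rather than a Redis fence.
        IndexingCoordination.request_cancellation(db_session, attempt.id)
        if attempt.celery_task_id:
            # Best-effort revoke so an already-dispatched task does not keep running.
            client_app.control.revoke(attempt.celery_task_id)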
@@ -106,7 +106,6 @@ from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
from onyx.server.documents.models import ConnectorCredentialPairIdentifier

@@ -421,7 +420,7 @@ def extract_zip_metadata(zf: zipfile.ZipFile) -> dict[str, Any]:
    return zip_metadata


def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResponse:
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
    for file in files:
        if not file.filename:
            raise HTTPException(status_code=400, detail="File name cannot be empty")

@@ -434,7 +433,7 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
    deduped_file_paths = []
    zip_metadata = {}
    try:
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        seen_zip = False
        for file in files:
            if file.content_type and file.content_type.startswith("application/zip"):

@@ -491,9 +490,8 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
def upload_files_api(
    files: list[UploadFile],
    _: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> FileUploadResponse:
    return upload_files(files, db_session)
    return upload_files(files)

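Many call sites in this diff drop the `db_session` argument because the default file store now talks to external (S3-compatible) storage directly. A minimal usage sketch of the new factory, built only from the calls that appear in this diff (the file name and contents are made up):

from io import BytesIO

from onyx.configs.constants import FileOrigin
from onyx.file_store.file_store import get_default_file_store

file_store = get_default_file_store()  # no SQLAlchemy session argument anymore
file_id = file_store.save_file(
    content=BytesIO(b"hello"),
    display_name="example.txt",
    file_origin=FileOrigin.OTHER,
    file_type="text/plain",
)
# Read it back as bytes, mirroring the read_file(..., mode="b") calls below.
content = file_store.read_file(file_id, mode="b").read()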
@router.get("/admin/connector")
|
||||
@@ -775,7 +773,7 @@ def get_connector_indexing_status(
|
||||
if secondary_index
|
||||
else get_current_search_settings
|
||||
)
|
||||
search_settings = get_search_settings(db_session)
|
||||
get_search_settings(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
# TODO remove this to enable ingestion API
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
@@ -787,16 +785,13 @@ def get_connector_indexing_status(
|
||||
# This may happen if background deletion is happening
|
||||
continue
|
||||
|
||||
in_progress = False
|
||||
if search_settings:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
in_progress = True
|
||||
|
||||
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
|
||||
(connector.id, credential.id)
|
||||
)
|
||||
in_progress = bool(
|
||||
latest_index_attempt
|
||||
and latest_index_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
|
||||
(connector.id, credential.id)
|
||||
|
||||
@@ -195,10 +195,9 @@ def undelete_persona(
@admin_router.post("/upload-image")
def upload_file(
    file: UploadFile,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> dict[str, str]:
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_type = ChatFileType.IMAGE
    file_id = file_store.save_file(
        content=file.file,

@@ -205,6 +205,6 @@ def create_deletion_attempt_for_connector_id(

    if cc_pair.connector.source == DocumentSource.FILE:
        connector = cc_pair.connector
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        for file_name in connector.connector_specific_config.get("file_locations", []):
            file_store.delete_file(file_name)

@@ -21,7 +21,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)

@@ -113,16 +113,13 @@ def upsert_ingestion_doc(

    information_content_classification_model = InformationContentClassificationModel()

    indexing_pipeline = build_indexing_pipeline(
    indexing_pipeline_result = run_indexing_pipeline(
        embedder=index_embedding_model,
        information_content_classification_model=information_content_classification_model,
        document_index=curr_doc_index,
        ignore_time_skip=True,
        db_session=db_session,
        tenant_id=tenant_id,
    )

    indexing_pipeline_result = indexing_pipeline(
        document_batch=[document],
        index_attempt_metadata=IndexAttemptMetadata(
            connector_id=cc_pair.connector_id,

@@ -148,16 +145,13 @@ def upsert_ingestion_doc(
        active_search_settings.secondary, None
    )

    sec_ind_pipeline = build_indexing_pipeline(
    run_indexing_pipeline(
        embedder=new_index_embedding_model,
        information_content_classification_model=information_content_classification_model,
        document_index=sec_doc_index,
        ignore_time_skip=True,
        db_session=db_session,
        tenant_id=tenant_id,
    )

    sec_ind_pipeline(
        document_batch=[document],
        index_attempt_metadata=IndexAttemptMetadata(
            connector_id=cc_pair.connector_id,

@@ -720,7 +720,7 @@ def upload_files_for_chat(
            detail="File size must be less than 20MB",
        )

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()

    file_info: list[tuple[str, str | None, ChatFileType]] = []
    for file in files:

@@ -823,10 +823,9 @@ def upload_files_for_chat(
@router.get("/file/{file_id:path}")
def fetch_chat_file(
    file_id: str,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> Response:
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_record = file_store.read_file_record(file_id)
    if not file_record:
        raise HTTPException(status_code=404, detail="File not found")

@@ -11,7 +11,6 @@ from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAU
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.file_store.file_store import get_default_file_store
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType

@@ -40,9 +39,8 @@ class OnyxRuntime:
        onyx_file: FileWithMimeType | None = None

        if db_filename:
            with get_session_with_shared_schema() as db_session:
                file_store = get_default_file_store(db_session)
                onyx_file = file_store.get_file_with_mime_type(db_filename)
            file_store = get_default_file_store()
            onyx_file = file_store.get_file_with_mime_type(db_filename)

        if not onyx_file:
            onyx_file = OnyxStaticFileManager.get_static(static_filename)

@@ -17,7 +17,6 @@ from requests import JSONDecodeError

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile

@@ -205,27 +204,26 @@ class CustomTool(BaseTool):
    def _save_and_get_file_references(
        self, file_content: bytes | str, content_type: str
    ) -> List[str]:
        with get_session_with_current_tenant() as db_session:
            file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()

            file_id = str(uuid.uuid4())
        file_id = str(uuid.uuid4())

            # Handle both binary and text content
            if isinstance(file_content, str):
                content = BytesIO(file_content.encode())
            else:
                content = BytesIO(file_content)
        # Handle both binary and text content
        if isinstance(file_content, str):
            content = BytesIO(file_content.encode())
        else:
            content = BytesIO(file_content)

            file_store.save_file(
                file_id=file_id,
                content=content,
                display_name=file_id,
                file_origin=FileOrigin.CHAT_UPLOAD,
                file_type=content_type,
                file_metadata={
                    "content_type": content_type,
                },
            )
        file_store.save_file(
            file_id=file_id,
            content=content,
            display_name=file_id,
            file_origin=FileOrigin.CHAT_UPLOAD,
            file_type=content_type,
            file_metadata={
                "content_type": content_type,
            },
        )

        return [file_id]

@@ -328,22 +326,21 @@ class CustomTool(BaseTool):

        # Load files from storage
        files = []
        with get_session_with_current_tenant() as db_session:
            file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()

            for file_id in response.tool_result.file_ids:
                try:
                    file_io = file_store.read_file(file_id, mode="b")
                    files.append(
                        InMemoryChatFile(
                            file_id=file_id,
                            filename=file_id,
                            content=file_io.read(),
                            file_type=file_type,
                        )
        for file_id in response.tool_result.file_ids:
            try:
                file_io = file_store.read_file(file_id, mode="b")
                files.append(
                    InMemoryChatFile(
                        file_id=file_id,
                        filename=file_id,
                        content=file_io.read(),
                        file_type=file_type,
                    )
                except Exception:
                    logger.exception(f"Failed to read file {file_id}")
                )
            except Exception:
                logger.exception(f"Failed to read file {file_id}")

        # Update prompt with file content
        prompt_builder.update_user_prompt(
@@ -207,6 +207,7 @@ def setup_logger(
    name: str = __name__,
    log_level: int = get_log_level_from_str(),
    extra: MutableMapping[str, Any] | None = None,
    propagate: bool = True,
) -> OnyxLoggingAdapter:
    logger = logging.getLogger(name)

@@ -244,6 +245,12 @@ def setup_logger(

    logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs)  # type: ignore

    # After handler configuration, disable propagation to avoid duplicate logs.
    # Prevent messages from propagating to the root logger, which can cause
    # duplicate log entries when the root logger is also configured with its
    # own handler (e.g. by Uvicorn / Celery).
    logger.propagate = propagate

    return OnyxLoggingAdapter(logger, extra=extra)

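A short usage sketch of the new `propagate` flag. Only the parameter itself comes from this diff; the logger names and the import path are illustrative.

from onyx.utils.logger import setup_logger  # import path assumed from the function's module

# Default behaviour: records still propagate to the root logger.
logger = setup_logger(name="onyx.docfetching")

# Processes that let Uvicorn / Celery configure the root logger can opt out
# to avoid every record being emitted twice.
quiet_logger = setup_logger(name="onyx.docprocessing", propagate=False)
quiet_logger.info("emitted by this logger's handlers only")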
@@ -384,3 +384,24 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
                )
                next_ind += 1
                del future_to_index[future]


def parallel_yield_from_funcs(
    funcs: list[Callable[..., R]],
    max_workers: int = 10,
) -> Iterator[R]:
    """
    Runs the list of functions with thread-level parallelism, yielding
    results as available. The asynchronous nature of this yielding means
    that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
    FURTHER ITEMS WERE PRODUCED by the input funcs. Only use this function
    if you are consuming all elements from the functions OR it is acceptable
    for some extra function code to run and not have the result(s) yielded.
    """

    def func_wrapper(func: Callable[[], R]) -> Iterator[R]:
        yield func()

    yield from parallel_yield(
        [func_wrapper(func) for func in funcs], max_workers=max_workers
    )

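A small usage sketch of `parallel_yield_from_funcs`; the callables and sleep times below are made up. Because each function is wrapped into a one-shot generator, results arrive as each callable finishes rather than in submission order.

import time
from functools import partial


def fetch(source: str, delay: float) -> str:
    time.sleep(delay)  # stand-in for a slow connector call
    return f"done: {source}"


funcs = [
    partial(fetch, "gdrive", 0.3),
    partial(fetch, "slack", 0.1),
    partial(fetch, "web", 0.2),
]

# Typically prints slack, then web, then gdrive.
for result in parallel_yield_from_funcs(funcs, max_workers=3):
    print(result)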
@@ -59,23 +59,23 @@ def run_jobs() -> None:
        "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation",
    ]

    cmd_worker_indexing = [
    cmd_worker_docprocessing = [
        "celery",
        "-A",
        "onyx.background.celery.versioned_apps.indexing",
        "onyx.background.celery.versioned_apps.docprocessing",
        "worker",
        "--pool=threads",
        "--concurrency=1",
        "--concurrency=6",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=indexing@%n",
        "--queues=connector_indexing",
        "--hostname=docprocessing@%n",
        "--queues=docprocessing",
    ]

    cmd_worker_user_files_indexing = [
        "celery",
        "-A",
        "onyx.background.celery.versioned_apps.indexing",
        "onyx.background.celery.versioned_apps.docfetching",
        "worker",
        "--pool=threads",
        "--concurrency=1",

@@ -111,6 +111,19 @@ def run_jobs() -> None:
        "--queues=kg_processing",
    ]

    cmd_worker_docfetching = [
        "celery",
        "-A",
        "onyx.background.celery.versioned_apps.docfetching",
        "worker",
        "--pool=threads",
        "--concurrency=1",
        "--prefetch-multiplier=1",
        "--loglevel=INFO",
        "--hostname=docfetching@%n",
        "--queues=connector_doc_fetching,user_files_indexing",
    ]

    cmd_beat = [
        "celery",
        "-A",

@@ -132,8 +145,11 @@ def run_jobs() -> None:
        cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )

    worker_indexing_process = subprocess.Popen(
        cmd_worker_indexing, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    worker_docprocessing_process = subprocess.Popen(
        cmd_worker_docprocessing,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    worker_user_files_indexing_process = subprocess.Popen(

@@ -157,6 +173,13 @@ def run_jobs() -> None:
        text=True,
    )

    worker_docfetching_process = subprocess.Popen(
        cmd_worker_docfetching,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )

    beat_process = subprocess.Popen(
        cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )

@@ -171,8 +194,8 @@ def run_jobs() -> None:
    worker_heavy_thread = threading.Thread(
        target=monitor_process, args=("HEAVY", worker_heavy_process)
    )
    worker_indexing_thread = threading.Thread(
        target=monitor_process, args=("INDEX", worker_indexing_process)
    worker_docprocessing_thread = threading.Thread(
        target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
    )
    worker_user_files_indexing_thread = threading.Thread(
        target=monitor_process,

@@ -184,24 +207,29 @@ def run_jobs() -> None:
    worker_kg_processing_thread = threading.Thread(
        target=monitor_process, args=("KG_PROCESSING", worker_kg_processing_process)
    )
    worker_docfetching_thread = threading.Thread(
        target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
    )
    beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))

    worker_primary_thread.start()
    worker_light_thread.start()
    worker_heavy_thread.start()
    worker_indexing_thread.start()
    worker_docprocessing_thread.start()
    worker_user_files_indexing_thread.start()
    worker_monitoring_thread.start()
    worker_kg_processing_thread.start()
    worker_docfetching_thread.start()
    beat_thread.start()

    worker_primary_thread.join()
    worker_light_thread.join()
    worker_heavy_thread.join()
    worker_indexing_thread.join()
    worker_docprocessing_thread.join()
    worker_user_files_indexing_thread.join()
    worker_monitoring_thread.join()
    worker_kg_processing_thread.join()
    worker_docfetching_thread.join()
    beat_thread.join()

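The dev runner above mirrors the split used in the deployment configs: a low-concurrency docfetching worker consumes connector_doc_fetching and user_files_indexing, while a higher-concurrency docprocessing worker consumes docprocessing. A minimal standalone sketch that launches just those two workers, reusing the exact argument lists shown above:

import subprocess

docfetching = subprocess.Popen(
    [
        "celery", "-A", "onyx.background.celery.versioned_apps.docfetching",
        "worker", "--pool=threads", "--concurrency=1", "--prefetch-multiplier=1",
        "--loglevel=INFO", "--hostname=docfetching@%n",
        "--queues=connector_doc_fetching,user_files_indexing",
    ]
)
docprocessing = subprocess.Popen(
    [
        "celery", "-A", "onyx.background.celery.versioned_apps.docprocessing",
        "worker", "--pool=threads", "--concurrency=6", "--prefetch-multiplier=1",
        "--loglevel=INFO", "--hostname=docprocessing@%n",
        "--queues=docprocessing",
    ]
)
docfetching.wait()
docprocessing.wait()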
@@ -213,7 +213,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:

    if file_names:
        logger.notice("Deleting stored files!")
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        for file_name in file_names:
            logger.notice(f"Deleting file {file_name}")
            file_store.delete_file(file_name)

@@ -65,12 +65,12 @@ autorestart=true
startsecs=10
stopasgroup=true

[program:celery_worker_indexing]
command=celery -A onyx.background.celery.versioned_apps.indexing worker
[program:celery_worker_docprocessing]
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
    --loglevel=INFO
    --hostname=indexing@%%n
    -Q connector_indexing
stdout_logfile=/var/log/celery_worker_indexing.log
    --hostname=docprocessing@%%n
    -Q docprocessing
stdout_logfile=/var/log/celery_worker_docprocessing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true

@@ -78,7 +78,7 @@ startsecs=10
stopasgroup=true

[program:celery_worker_user_files_indexing]
command=celery -A onyx.background.celery.versioned_apps.indexing worker
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
    --loglevel=INFO
    --hostname=user_files_indexing@%%n
    -Q user_files_indexing

@@ -89,6 +89,18 @@ autorestart=true
startsecs=10
stopasgroup=true

[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
    --loglevel=INFO
    --hostname=docfetching@%%n
    -Q connector_doc_fetching
stdout_logfile=/var/log/celery_worker_docfetching.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true

[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
    --loglevel=INFO

@@ -161,6 +173,7 @@ command=tail -qF
    /var/log/celery_worker_heavy.log
    /var/log/celery_worker_indexing.log
    /var/log/celery_worker_user_files_indexing.log
    /var/log/celery_worker_docfetching.log
    /var/log/celery_worker_monitoring.log
    /var/log/slack_bot.log
    /var/log/supervisord_watchdog_celery_beat.log

@@ -30,14 +30,12 @@ def mock_filestore_record() -> MagicMock:


@patch("onyx.connectors.file.connector.get_default_file_store")
@patch("onyx.connectors.file.connector.get_session_with_current_tenant")
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None
)
def test_single_text_file_with_metadata(
    mock_get_unstructured_api_key: MagicMock,
    mock_get_session: MagicMock,
    mock_get_filestore: MagicMock,
    mock_db_session: MagicMock,
    mock_file_store: MagicMock,
    mock_filestore_record: MagicMock,

@@ -48,13 +46,18 @@ def test_single_text_file_with_metadata(
        "doc_updated_at": "2001-01-01T00:00:00Z"}\n'
        b"Test answer is 12345"
    )
    mock_get_filestore = MagicMock()
    mock_get_filestore.return_value = mock_file_store
    mock_file_store.read_file_record.return_value = mock_filestore_record
    mock_get_session.return_value.__enter__.return_value = mock_db_session
    mock_file_store.read_file.return_value = file_content

    connector = LocalFileConnector(file_locations=["test.txt"], zip_metadata={})
    batches = list(connector.load_from_state())
    with patch(
        "onyx.connectors.file.connector.get_default_file_store",
        return_value=mock_file_store,
    ):
        connector = LocalFileConnector(file_locations=["test.txt"], zip_metadata={})
        batches = list(connector.load_from_state())

    assert len(batches) == 1
    docs = batches[0]

@@ -69,26 +72,22 @@ def test_single_text_file_with_metadata(
    assert doc.doc_updated_at == datetime(2001, 1, 1, 0, 0, 0, tzinfo=timezone.utc)


@patch("onyx.connectors.file.connector.get_default_file_store")
@patch("onyx.connectors.file.connector.get_session_with_current_tenant")
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key", return_value=None
)
def test_two_text_files_with_zip_metadata(
    mock_get_unstructured_api_key: MagicMock,
    mock_get_session: MagicMock,
    mock_get_filestore: MagicMock,
    mock_db_session: MagicMock,
    mock_file_store: MagicMock,
) -> None:
    file1_content = io.BytesIO(b"File 1 content")
    file2_content = io.BytesIO(b"File 2 content")
    mock_get_filestore = MagicMock()
    mock_get_filestore.return_value = mock_file_store
    mock_file_store.read_file_record.side_effect = [
        MagicMock(file_id=str(uuid4()), display_name="file1.txt"),
        MagicMock(file_id=str(uuid4()), display_name="file2.txt"),
    ]
    mock_get_session.return_value.__enter__.return_value = mock_db_session
    mock_file_store.read_file.side_effect = [file1_content, file2_content]
    zip_metadata = {
        "file1.txt": {

@@ -109,10 +108,14 @@ def test_two_text_files_with_zip_metadata(
        },
    }

    connector = LocalFileConnector(
        file_locations=["file1.txt", "file2.txt"], zip_metadata=zip_metadata
    )
    batches = list(connector.load_from_state())
    with patch(
        "onyx.connectors.file.connector.get_default_file_store",
        return_value=mock_file_store,
    ):
        connector = LocalFileConnector(
            file_locations=["file1.txt", "file2.txt"], zip_metadata=zip_metadata
        )
        batches = list(connector.load_from_state())

    assert len(batches) == 1
    docs = batches[0]

@@ -70,20 +70,19 @@ def _get_all_backend_configs() -> List[BackendConfig]:
    if S3_ENDPOINT_URL:
        minio_access_key = "minioadmin"
        minio_secret_key = "minioadmin"
        if minio_access_key and minio_secret_key:
            configs.append(
                {
                    "endpoint_url": S3_ENDPOINT_URL,
                    "access_key": minio_access_key,
                    "secret_key": minio_secret_key,
                    "region": "us-east-1",
                    "verify_ssl": False,
                    "backend_name": "MinIO",
                }
            )
        configs.append(
            {
                "endpoint_url": S3_ENDPOINT_URL,
                "access_key": minio_access_key,
                "secret_key": minio_secret_key,
                "region": "us-east-1",
                "verify_ssl": False,
                "backend_name": "MinIO",
            }
        )

    # AWS S3 configuration (if credentials are available)
    if S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
    elif S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
        configs.append(
            {
                "endpoint_url": None,

@@ -116,7 +115,6 @@ def file_store(

    # Create S3BackedFileStore with backend-specific configuration
    store = S3BackedFileStore(
        db_session=db_session,
        bucket_name=TEST_BUCKET_NAME,
        aws_access_key_id=backend_config["access_key"],
        aws_secret_access_key=backend_config["secret_key"],

@@ -827,7 +825,6 @@ class TestS3BackedFileStore:
        # Create a new database session for each worker to avoid conflicts
        with get_session_with_current_tenant() as worker_session:
            worker_file_store = S3BackedFileStore(
                db_session=worker_session,
                bucket_name=current_bucket_name,
                aws_access_key_id=current_access_key,
                aws_secret_access_key=current_secret_key,

@@ -849,6 +846,7 @@ class TestS3BackedFileStore:
                display_name=f"Worker {worker_id} File",
                file_origin=file_origin,
                file_type=file_type,
                db_session=worker_session,
            )
            results.append((file_name, content))
            return True
@@ -885,3 +883,94 @@ class TestS3BackedFileStore:
        read_content_io = file_store.read_file(file_id)
        actual_content: str = read_content_io.read().decode("utf-8")
        assert actual_content == expected_content

    def test_list_files_by_prefix(self, file_store: S3BackedFileStore) -> None:
        """Test listing files by prefix returns only correctly prefixed files"""
        test_prefix = "documents-batch-"

        # Files that should be returned (start with the prefix)
        prefixed_files: List[str] = [
            f"{test_prefix}001.txt",
            f"{test_prefix}002.json",
            f"{test_prefix}abc.pdf",
            f"{test_prefix}xyz-final.docx",
        ]

        # Files that should NOT be returned (don't start with prefix, even if they contain it)
        non_prefixed_files: List[str] = [
            f"other-{test_prefix}001.txt",  # Contains prefix but doesn't start with it
            f"backup-{test_prefix}data.txt",  # Contains prefix but doesn't start with it
            f"{uuid.uuid4()}.txt",  # Random file without prefix
            "reports-001.pdf",  # Different prefix
            f"my-{test_prefix[:-1]}.txt",  # Similar but not exact prefix
        ]

        all_files = prefixed_files + non_prefixed_files
        saved_file_ids: List[str] = []

        # Save all test files
        for file_name in all_files:
            content = f"Content for {file_name}"
            content_io = BytesIO(content.encode("utf-8"))

            returned_file_id = file_store.save_file(
                content=content_io,
                display_name=f"Display: {file_name}",
                file_origin=FileOrigin.OTHER,
                file_type="text/plain",
                file_id=file_name,
            )
            saved_file_ids.append(returned_file_id)

            # Verify file was saved
            assert returned_file_id == file_name

        # Test the list_files_by_prefix functionality
        prefix_results = file_store.list_files_by_prefix(test_prefix)

        # Extract file IDs from results
        returned_file_ids = [record.file_id for record in prefix_results]

        # Verify correct number of files returned
        assert len(returned_file_ids) == len(prefixed_files), (
            f"Expected {len(prefixed_files)} files with prefix '{test_prefix}', "
            f"but got {len(returned_file_ids)}: {returned_file_ids}"
        )

        # Verify all prefixed files are returned
        for expected_file_id in prefixed_files:
            assert expected_file_id in returned_file_ids, (
                f"File '{expected_file_id}' should be in results but was not found. "
                f"Returned files: {returned_file_ids}"
            )

        # Verify no non-prefixed files are returned
        for unexpected_file_id in non_prefixed_files:
            assert unexpected_file_id not in returned_file_ids, (
                f"File '{unexpected_file_id}' should NOT be in results but was found. "
                f"Returned files: {returned_file_ids}"
            )

        # Verify the returned records have correct properties
        for record in prefix_results:
            assert record.file_id.startswith(test_prefix)
            assert record.display_name == f"Display: {record.file_id}"
            assert record.file_origin == FileOrigin.OTHER
            assert record.file_type == "text/plain"
            assert record.bucket_name == file_store._get_bucket_name()

        # Test with empty prefix (should return all files we created)
        all_results = file_store.list_files_by_prefix("")
        all_returned_ids = [record.file_id for record in all_results]

        # Should include all our test files
        for file_id in saved_file_ids:
            assert (
                file_id in all_returned_ids
            ), f"File '{file_id}' should be in results for empty prefix"

        # Test with non-existent prefix
        nonexistent_results = file_store.list_files_by_prefix("nonexistent-prefix-")
        assert (
            len(nonexistent_results) == 0
        ), "Should return empty list for non-existent prefix"

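Outside of tests, this prefix listing is what lets the decoupled pipeline find and clean up stored document batches. A small usage sketch, assuming only the FileStore methods exercised above (the batch prefix is illustrative and matches the test):

from onyx.file_store.file_store import get_default_file_store

file_store = get_default_file_store()

batch_prefix = "documents-batch-"  # hypothetical prefix for stored batches
for record in file_store.list_files_by_prefix(batch_prefix):
    # Each record carries the metadata asserted on in the test above.
    print(record.file_id, record.display_name, record.file_type)
    file_store.delete_file(record.file_id)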
@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "127.0.0.1"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 60
MAX_DELAY = 300

GENERAL_HEADERS = {"Content-Type": "application/json"}


@@ -203,7 +203,9 @@ class IndexAttemptManager:
            )

            if index_attempt.status and index_attempt.status.is_terminal():
                print(f"IndexAttempt {index_attempt_id} completed")
                print(
                    f"IndexAttempt {index_attempt_id} completed with status {index_attempt.status}"
                )
                return

            elapsed = time.monotonic() - start

@@ -216,6 +218,7 @@ class IndexAttemptManager:
                f"Waiting for IndexAttempt {index_attempt_id} to complete. "
                f"elapsed={elapsed:.2f} timeout={timeout}"
            )
            time.sleep(5)

    @staticmethod
    def get_index_attempt_errors_for_cc_pair(

@@ -22,6 +22,7 @@ from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.file_store.file_store import get_default_file_store
from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa

@@ -398,6 +399,13 @@ def reset_vespa_multitenant() -> None:
        time.sleep(5)


def reset_file_store() -> None:
    """Reset the FileStore."""
    filestore = get_default_file_store()
    for file_record in filestore.list_files_by_prefix(""):
        filestore.delete_file(file_record.file_id)


def reset_all() -> None:
    if os.environ.get("SKIP_RESET", "").lower() == "true":
        logger.info("Skipping reset.")

@@ -407,6 +415,8 @@ def reset_all() -> None:
    reset_postgres()
    logger.info("Resetting Vespa...")
    reset_vespa()
    logger.info("Resetting FileStore...")
    reset_file_store()


def reset_all_multitenant() -> None:
@@ -60,7 +60,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
    )

    CCPairManager.wait_for_indexing_completion(
        cc_pair_1, now, timeout=120, user_performing_action=admin_user
        cc_pair_1, now, timeout=300, user_performing_action=admin_user
    )

    now = datetime.now(timezone.utc)

@@ -74,7 +74,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
    )

    CCPairManager.wait_for_indexing_completion(
        cc_pair_2, now, timeout=120, user_performing_action=admin_user
        cc_pair_2, now, timeout=300, user_performing_action=admin_user
    )

    info_1 = CCPairManager.get_single(cc_pair_1.id, user_performing_action=admin_user)

@@ -90,7 +90,7 @@ def test_image_indexing(
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=datetime.now(timezone.utc),
        timeout=180,
        timeout=300,
        user_performing_action=admin_user,
    )


@@ -343,8 +343,6 @@ def test_mock_connector_checkpoint_recovery(
    """Test that checkpointing works correctly when an unhandled exception occurs
    and that subsequent runs pick up from the last successful checkpoint."""
    # Create test documents
    # Create 100 docs for first batch, this is needed to get past the
    # `_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT` logic in `get_latest_valid_checkpoint`.
    docs_batch_1 = [create_test_document() for _ in range(100)]
    doc2 = create_test_document()
    doc3 = create_test_document()

@@ -421,10 +419,13 @@ def test_mock_connector_checkpoint_recovery(
        db_session=db_session,
        vespa_client=vespa_client,
    )
    assert len(documents) == 101  # 100 docs from first batch + doc2
    document_ids = {doc.id for doc in documents}
    assert doc2.id in document_ids
    assert all(doc.id in document_ids for doc in docs_batch_1)
    # This is no longer guaranteed because docfetching and docprocessing are decoupled!
    # Some batches may not be processed when docfetching fails, but they should still stick around
    # in the filestore and be ready for the next run.
    # assert len(documents) == 101  # 100 docs from first batch + doc2
    # document_ids = {doc.id for doc in documents}
    # assert doc2.id in document_ids
    # assert all(doc.id in document_ids for doc in docs_batch_1)

    # Get the checkpoints that were sent to the mock server
    response = mock_server_client.get("/get-checkpoints")

@@ -3,7 +3,7 @@ import uuid

import httpx

from onyx.background.celery.tasks.indexing.utils import (
from onyx.background.celery.tasks.docprocessing.utils import (
    NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE,
)
from onyx.configs.constants import DocumentSource

@@ -136,7 +136,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
    )

    CCPairManager.wait_for_indexing_completion(
        cc_pair_1, now, timeout=120, user_performing_action=admin_user
        cc_pair_1, now, timeout=300, user_performing_action=admin_user
    )

    selected_cc_pair = CCPairManager.get_indexing_status_by_id(

@@ -77,18 +77,15 @@ def sample_file_io(sample_content: bytes) -> BytesIO:
class TestExternalStorageFileStore:
    """Test external storage file store functionality (S3-compatible)"""

    def test_get_default_file_store_s3(self, db_session: Session) -> None:
    def test_get_default_file_store_s3(self) -> None:
        """Test that external storage file store is returned"""
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        assert isinstance(file_store, S3BackedFileStore)

    def test_s3_client_initialization_with_credentials(
        self, db_session: Session
    ) -> None:
    def test_s3_client_initialization_with_credentials(self) -> None:
        """Test S3 client initialization with explicit credentials"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="test-key",
                aws_secret_access_key="test-secret",

@@ -110,7 +107,6 @@ class TestExternalStorageFileStore:
        """Test S3 client initialization with IAM role (no explicit credentials)"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id=None,
                aws_secret_access_key=None,

@@ -129,16 +125,16 @@ class TestExternalStorageFileStore:
            assert "aws_access_key_id" not in call_kwargs
            assert "aws_secret_access_key" not in call_kwargs

    def test_s3_bucket_name_configuration(self, db_session: Session) -> None:
    def test_s3_bucket_name_configuration(self) -> None:
        """Test S3 bucket name configuration"""
        with patch(
            "onyx.file_store.file_store.S3_FILE_STORE_BUCKET_NAME", "my-test-bucket"
        ):
            file_store = S3BackedFileStore(db_session, bucket_name="my-test-bucket")
            file_store = S3BackedFileStore(bucket_name="my-test-bucket")
            bucket_name: str = file_store._get_bucket_name()
            assert bucket_name == "my-test-bucket"

    def test_s3_key_generation_default_prefix(self, db_session: Session) -> None:
    def test_s3_key_generation_default_prefix(self) -> None:
        """Test S3 key generation with default prefix"""
        with (
            patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"),

@@ -147,11 +143,11 @@ class TestExternalStorageFileStore:
                return_value="test-tenant",
            ),
        ):
            file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
            file_store = S3BackedFileStore(bucket_name="test-bucket")
            s3_key: str = file_store._get_s3_key("test-file.txt")
            assert s3_key == "onyx-files/test-tenant/test-file.txt"

    def test_s3_key_generation_custom_prefix(self, db_session: Session) -> None:
    def test_s3_key_generation_custom_prefix(self) -> None:
        """Test S3 key generation with custom prefix"""
        with (
            patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "custom-prefix"),

@@ -161,17 +157,15 @@ class TestExternalStorageFileStore:
            ),
        ):
            file_store = S3BackedFileStore(
                db_session, bucket_name="test-bucket", s3_prefix="custom-prefix"
                bucket_name="test-bucket", s3_prefix="custom-prefix"
            )
            s3_key: str = file_store._get_s3_key("test-file.txt")
            assert s3_key == "custom-prefix/test-tenant/test-file.txt"

    def test_s3_key_generation_with_different_tenant_ids(
        self, db_session: Session
    ) -> None:
    def test_s3_key_generation_with_different_tenant_ids(self) -> None:
        """Test S3 key generation with different tenant IDs"""
        with patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"):
            file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
            file_store = S3BackedFileStore(bucket_name="test-bucket")

            # Test with tenant ID "tenant-1"
            with patch(

@@ -224,9 +218,7 @@ class TestExternalStorageFileStore:
        with patch("onyx.db.file_record.upsert_filerecord") as mock_upsert:
            mock_upsert.return_value = Mock()

            file_store = S3BackedFileStore(
                mock_db_session, bucket_name="test-bucket"
            )
            file_store = S3BackedFileStore(bucket_name="test-bucket")

            # This should not raise an exception
            file_store.save_file(

@@ -235,6 +227,7 @@ class TestExternalStorageFileStore:
                display_name="Test File",
                file_origin=FileOrigin.OTHER,
                file_type="text/plain",
                db_session=mock_db_session,
            )

            # Verify S3 client was called correctly

@@ -244,14 +237,13 @@ class TestExternalStorageFileStore:
            assert call_args[1]["Key"] == "onyx-files/public/test-file.txt"
            assert call_args[1]["ContentType"] == "text/plain"

    def test_minio_client_initialization(self, db_session: Session) -> None:
    def test_minio_client_initialization(self) -> None:
        """Test S3 client initialization with MinIO endpoint"""
        with (
            patch("boto3.client") as mock_boto3,
            patch("urllib3.disable_warnings"),
        ):
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="minioadmin",
                aws_secret_access_key="minioadmin",

@@ -277,11 +269,10 @@ class TestExternalStorageFileStore:
            assert config.signature_version == "s3v4"
            assert config.s3["addressing_style"] == "path"

    def test_minio_ssl_verification_enabled(self, db_session: Session) -> None:
    def test_minio_ssl_verification_enabled(self) -> None:
        """Test MinIO with SSL verification enabled"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="test-key",
                aws_secret_access_key="test-secret",

@@ -295,11 +286,10 @@ class TestExternalStorageFileStore:
            assert "verify" not in call_kwargs or call_kwargs.get("verify") is not False
            assert call_kwargs["endpoint_url"] == "https://minio.example.com"

    def test_aws_s3_without_endpoint_url(self, db_session: Session) -> None:
    def test_aws_s3_without_endpoint_url(self) -> None:
        """Test that regular AWS S3 doesn't include endpoint URL or custom config"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="test-key",
                aws_secret_access_key="test-secret",

@@ -321,8 +311,8 @@ class TestFileStoreInterface:
class TestFileStoreInterface:
    """Test the general file store interface"""

    def test_file_store_always_external_storage(self, db_session: Session) -> None:
    def test_file_store_always_external_storage(self) -> None:
        """Test that external storage file store is always returned"""
        # File store should always be S3BackedFileStore regardless of environment
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        assert isinstance(file_store, S3BackedFileStore)

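The tests above reflect the constructor change: S3BackedFileStore no longer takes a db_session; where a session is still needed it is passed to individual operations such as save_file. A minimal construction sketch using only parameters that appear in this diff (the credential values mirror the MinIO defaults used in the tests, and the import path is assumed from the patches above):

from onyx.file_store.file_store import S3BackedFileStore  # import path assumed

store = S3BackedFileStore(
    bucket_name="test-bucket",
    aws_access_key_id="minioadmin",       # omit both keys to fall back to IAM-role credentials
    aws_secret_access_key="minioadmin",
    s3_prefix="onyx-files",
)
print(store._get_bucket_name())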