Compare commits

...

13 Commits

Author SHA1 Message Date
Evan Lohn
81ffcc00db fix migration 2025-07-06 14:46:30 -07:00
Evan Lohn
cc3ee12bff renaming and docstrings (untested) 2025-07-06 13:53:44 -07:00
Evan Lohn
70a01680ff remove unused db session in prep for new approach 2025-07-06 13:53:44 -07:00
Evan Lohn
8da82dbfaf refactor 2025-07-06 13:53:42 -07:00
Evan Lohn
56e749dcee catastrophe handling 2025-07-06 13:47:55 -07:00
Evan Lohn
f2ca7a8769 working v1 of decoupled 2025-07-06 13:47:55 -07:00
Evan Lohn
c3de6a8e49 import fixes 2025-07-06 13:47:55 -07:00
Evan Lohn
c8cd85b284 WIP 2025-07-06 13:47:53 -07:00
Evan Lohn
a899254766 WIP: can suceed but status is error 2025-07-06 13:47:01 -07:00
Evan Lohn
d33adfa91a bug fixes and finally add document batch storage 2025-07-06 13:47:01 -07:00
Evan Lohn
b9df82c5a1 minio migration 2025-07-06 13:47:01 -07:00
Evan Lohn
b8f652109f renamed and moved tasks (WIP) 2025-07-06 13:47:01 -07:00
Evan Lohn
3aeaff7cda WIP 2025-07-06 13:47:01 -07:00
71 changed files with 3187 additions and 1260 deletions

View File

@@ -9,7 +9,7 @@ Create Date: 2025-06-20 14:44:54.241159
from alembic import op
import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
@@ -175,6 +175,7 @@ def update_document_id_in_database(old_doc_id: str, new_doc_id: str) -> None:
)
# Update search_doc table (stores search results for chat replay)
# This is critical for agent functionality
bind.execute(
sa.text(
"UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
@@ -396,6 +397,39 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
try:
bind = op.get_bind()
# Delete from agent-related tables first (order matters due to foreign keys)
# Delete from agent__sub_query__search_doc first since it references search_doc
bind.execute(
sa.text(
"""
DELETE FROM agent__sub_query__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Delete from chat_message__search_doc
bind.execute(
sa.text(
"""
DELETE FROM chat_message__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Now we can safely delete from search_doc
bind.execute(
sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from document_by_connector_credential_pair
bind.execute(
sa.text(
@@ -405,11 +439,6 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
)
# Delete from other tables that reference this document
bind.execute(
sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
@@ -531,7 +560,6 @@ def upgrade() -> None:
updated_count += 1
except Exception as e:
print(f"Failed to update document {current_doc_id}: {e}")
from httpx import HTTPStatusError
if isinstance(e, HTTPStatusError):
print(f"HTTPStatusError: {e}")

View File

@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store(db_session=session)
external_store = get_s3_file_store()
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(

View File

@@ -91,7 +91,7 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store(db_session).save_file(
get_default_file_store().save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -422,7 +422,7 @@ def connector_permission_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
+ f"_{redis_connector.cc_pair_id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)

View File

@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
db_session: Session,
report_display_name: str,
) -> IO:
"""
@@ -128,7 +127,7 @@ def get_usage_report_data(
Returns:
The usage report data.
"""
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True

View File

@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
upload_logo(file=file, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -6,7 +6,6 @@ from typing import IO
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
return "application/octet-stream"
def upload_logo(
db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
content: IO[Any]
if isinstance(file, str):
@@ -129,7 +126,7 @@ def upload_logo(
display_name = file.filename
file_type = file.content_type or "image/jpeg"
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=content,
display_name=display_name,

View File

@@ -358,7 +358,7 @@ def get_query_history_export_status(
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we should then check if the export file has already been stored inside of the file-store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
@@ -385,7 +385,7 @@ def download_query_history_csv(
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
has_file = file_store.has_file(
file_id=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -53,7 +53,7 @@ def read_usage_report(
db_session: Session = Depends(get_session),
) -> Response:
try:
file = get_usage_report_data(db_session, report_name)
file = get_usage_report_data(report_name)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -112,7 +112,7 @@ def create_new_usage_report(
period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
report_id = str(uuid.uuid4())
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
messages_file_id = generate_chat_messages_report(
db_session, file_store, report_id, period

View File

@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
store_ee_settings(final_enterprise_settings)
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
def _seed_logo(logo_path: str | None) -> None:
if logo_path:
logger.notice("Uploading logo")
upload_logo(db_session=db_session, file=logo_path)
upload_logo(file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_logo(seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)

View File

@@ -0,0 +1,102 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()

# Dedicated celery app for the docfetching worker.
# Configuration comes from onyx.background.celery.configs.docfetching and every
# task runs as app_base.TenantAwareTask.
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Forward celery's task_prerun signal to the shared app_base handler."""
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Forward celery's task_postrun signal to the shared app_base handler."""
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    """Forward the celeryd_init signal to the shared app_base handler."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    """Initialize the docfetching worker: name/size the DB engine, then wait
    for backing services before accepting work."""
    logger.info("worker_init signal received.")

    # Tag this worker's postgres connections with the docfetching app name.
    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)

    # Size the connection pool to the worker's configured concurrency.
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)

    # Block startup until redis, postgres, and vespa are reachable
    # (the vespa check may shut the worker down on failure).
    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:
        return

    app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Forward the worker_ready signal to the shared app_base handler."""
    app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Forward the worker_shutdown signal to the shared app_base handler."""
    app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Forward celery's setup_logging signal to the shared app_base handler."""
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# Install the bootsteps shared by all Onyx celery workers into this app.
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

# This app only registers the docfetching task package.
celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.docfetching",
    ]
)

View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY
# Broker/transport settings are inherited from the shared celery config so the
# docfetching worker talks to the same broker as every other Onyx worker.
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

# Redis client behavior mirrors the shared config.
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

# Result backend settings also come from the shared config.
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# Document processing worker configuration
# Similar to indexing worker since it does similar CPU/memory intensive work
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -0,0 +1,713 @@
import multiprocessing
import os
import time
import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.tasks import ConnectorIndexingLogBuilder
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.models import ConnectorIndexingContext
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
def _wait_for_fence_payload(
    redis_connector: RedisConnector,
    redis_connector_index: RedisConnectorIndex,
    index_attempt_id: int,
) -> RedisConnectorIndexPayload:
    """Block until the indexing fence payload for this attempt is fully written.

    This wait is needed to avoid a race condition where the primary worker
    sends the task and it is immediately executed before the primary worker
    can finalize the fence (i.e. before index_attempt_id / celery_task_id are
    set on the payload).

    Args:
        redis_connector: redis wrapper for the connector-credential pair.
        redis_connector_index: redis wrapper for this attempt's indexing fence.
        index_attempt_id: the attempt this spawned process was created for.

    Returns:
        The fully-populated fence payload.

    Raises:
        SimpleJobException: if the fence does not become ready within
            CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT, if the fence or payload is
            missing, or if the payload belongs to a different index attempt
            (i.e. this task is stale).
    """
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            # BUGFIX: this message previously reported
            # redis_connector.permissions.fence_key (the permissions-sync
            # fence). The fence being waited on is the indexing fence, so
            # report that key, matching the other messages in this function.
            raise SimpleJobException(
                f"docfetching_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector_index.fence_key}",
                code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
            )

        if not redis_connector_index.fenced:  # The fence must exist
            raise SimpleJobException(
                f"docfetching_task - fence not found: fence={redis_connector_index.fence_key}",
                code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
            )

        payload = redis_connector_index.payload  # The payload must exist
        if not payload:
            raise SimpleJobException(
                "docfetching_task: payload invalid or not found",
                code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
            )

        if payload.index_attempt_id is None or payload.celery_task_id is None:
            # fence exists but is not yet finalized by the primary worker
            logger.info(
                f"docfetching_task - Waiting for fence: fence={redis_connector_index.fence_key}"
            )
            sleep(1)
            continue

        if payload.index_attempt_id != index_attempt_id:
            raise SimpleJobException(
                f"docfetching_task - id mismatch. Task may be left over from previous run.: "
                f"task_index_attempt={index_attempt_id} "
                f"payload_index_attempt={payload.index_attempt_id}",
                code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
            )

        logger.info(
            f"docfetching_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
        )
        return payload
def docfetching_task(
    app: Celery,
    index_attempt_id: int,
    cc_pair_id: int,
    search_settings_id: int,
    is_ee: bool,
    tenant_id: str,
) -> None:
    """
    This function is run in a SimpleJob as a new process. It is responsible for validating
    some stuff, but basically it just calls run_indexing_entrypoint.

    Args:
        app: celery app, forwarded to run_indexing_entrypoint.
        index_attempt_id: id of the index attempt this process services.
        cc_pair_id: connector-credential pair being indexed.
        search_settings_id: search settings to index under.
        is_ee: whether enterprise-edition code paths are active.
        tenant_id: tenant whose data is being indexed.

    NOTE: if an exception is raised out of this task, the primary worker will detect
    that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
    This will cause the primary worker to abort the indexing attempt and clean up.
    """
    # Since connector_indexing_proxy_task spawns a new process using this function as
    # the entrypoint, we init Sentry here.
    if SENTRY_DSN:
        sentry_sdk.init(
            dsn=SENTRY_DSN,
            traces_sample_rate=0.1,
        )
        logger.info("Sentry initialized")
    else:
        logger.debug("Sentry DSN not provided, skipping Sentry initialization")

    logger.info(
        f"Indexing spawned task starting: "
        f"attempt={index_attempt_id} "
        f"tenant={tenant_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
    )

    # 20 is the documented default for httpx max_keepalive_connections
    if MANAGED_VESPA:
        httpx_init_vespa_pool(
            20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
        )
    else:
        httpx_init_vespa_pool(20)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # Refuse to start work if the connector is being deleted or stopped.
    if redis_connector.delete.fenced:
        raise SimpleJobException(
            f"Indexing will not start because connector deletion is in progress: "
            f"attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"fence={redis_connector.delete.fence_key}",
            code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
        )

    if redis_connector.stop.fenced:
        raise SimpleJobException(
            f"Indexing will not start because a connector stop signal was detected: "
            f"attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"fence={redis_connector.stop.fence_key}",
            code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
        )

    # Wait for the primary worker to finish writing the fence payload; raises
    # (with a terminal status code) on timeout, missing fence, or id mismatch.
    payload = _wait_for_fence_payload(
        redis_connector, redis_connector_index, index_attempt_id
    )

    r = get_redis_client()

    # set thread_local=False since we don't control what thread the indexing/pruning
    # might run our callback with
    lock: RedisLock = r.lock(
        redis_connector_index.generator_lock_key,
        timeout=CELERY_INDEXING_LOCK_TIMEOUT,
        thread_local=False,
    )

    # Non-blocking acquire: if another process holds the generator lock, this
    # attempt is already being serviced and we must bail out.
    acquired = lock.acquire(blocking=False)
    if not acquired:
        logger.warning(
            f"Indexing task already running, exiting...: "
            f"index_attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

        raise SimpleJobException(
            f"Indexing task already running, exiting...: "
            f"index_attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}",
            code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
        )

    # Record the actual start time on the fence so the watchdog can see it.
    payload.started = datetime.now(timezone.utc)
    redis_connector_index.set_fence(payload)

    try:
        # Validate that the attempt and its cc_pair still exist before running.
        with get_session_with_current_tenant() as db_session:
            attempt = get_index_attempt(db_session, index_attempt_id)
            if not attempt:
                raise SimpleJobException(
                    f"Index attempt not found: index_attempt={index_attempt_id}",
                    code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
                )

            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair_id,
            )
            if not cc_pair:
                raise SimpleJobException(
                    f"cc_pair not found: cc_pair={cc_pair_id}",
                    code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
                )

            # TODO: pretty sure it's impossible for these to happen, should probably delete
            if not cc_pair.connector:
                raise SimpleJobException(
                    f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
                    code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
                )

            if not cc_pair.credential:
                raise SimpleJobException(
                    f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
                    code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
                )

        # define a callback class
        # os.getppid() is the watchdog's pid: the callback can detect parent death.
        callback = IndexingCallback(
            os.getppid(),
            redis_connector,
            lock,
            r,
            redis_connector_index,
        )

        logger.info(
            f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
            f"tenant={tenant_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

        # This is where the heavy/real work happens
        run_indexing_entrypoint(
            app,
            index_attempt_id,
            tenant_id,
            cc_pair_id,
            is_ee,
            callback=callback,
        )

        # get back the total number of indexed docs and return it
        redis_connector_index.get_progress()

        # Signal the watchdog that the generator finished successfully.
        redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
    except ConnectorValidationError:
        raise SimpleJobException(
            f"Indexing task failed: attempt={index_attempt_id} "
            f"tenant={tenant_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}",
            code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
        )
    except Exception as e:
        logger.exception(
            f"Indexing spawned task failed: attempt={index_attempt_id} "
            f"tenant={tenant_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

        # special bulletproofing ... truncate long exception messages
        # for exception types that require more args, this will fail
        # thus the try/except
        try:
            sanitized_e = type(e)(str(e)[:1024])
            sanitized_e.__traceback__ = e.__traceback__
            raise sanitized_e
        except Exception:
            raise e
    finally:
        if lock.owned():
            lock.release()

        logger.info(
            f"Indexing spawned task finished: attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

    # NOTE(review): assumed to sit after the try/finally (success path only) so
    # that failure paths can propagate their exception/exit code to the
    # SimpleJob machinery. If this were inside the finally, every failure would
    # exit 0 and be misreported as success — confirm against original indentation.
    os._exit(0)  # ensure process exits cleanly
def process_job_result(
    job: SimpleJob,
    connector_source: str | None,
    redis_connector_index: RedisConnectorIndex,
    log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
    """Translate a finished SimpleJob into a SimpleJobResult.

    A job that did not end in the "error" state is treated as a success. For
    error states, the redis completion signal is consulted first: if it reports
    HTTP 200, the non-zero exit code is ignored and the result is still a
    success. Otherwise the exit code is mapped to a terminal status and the
    job's exception text is captured.
    """
    outcome = SimpleJobResult()
    outcome.connector_source = connector_source

    if job.process:
        outcome.exit_code = job.process.exitcode

    # Any non-error job state counts as success.
    if job.status != "error":
        outcome.status = IndexingWatchdogTerminalStatus.SUCCEEDED
        return outcome

    # In EKS, there is an edge case where successful tasks return exit
    # code 1 in the cloud due to the set_spawn_method not sticking.
    # We've since worked around this, but the following is a safe way to
    # work around this issue. Basically, we ignore the job error state
    # if the completion signal is OK.
    completion_code = redis_connector_index.get_completion()
    completed_ok = bool(completion_code) and HTTPStatus(completion_code) == HTTPStatus.OK

    if completed_ok:
        outcome.status = IndexingWatchdogTerminalStatus.SUCCEEDED
        task_logger.warning(
            log_builder.build(
                "Indexing watchdog - spawned task has non-zero exit code "
                "but completion signal is OK. Continuing...",
                exit_code=str(outcome.exit_code),
            )
        )
        return outcome

    # Genuine error: map the exit code to a terminal status if we have one,
    # and capture the spawned task's exception text.
    if outcome.exit_code is not None:
        outcome.status = IndexingWatchdogTerminalStatus.from_code(outcome.exit_code)
    outcome.exception_str = job.exception()
    return outcome
@shared_task(
    name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
    bind=True,
    acks_late=False,
    track_started=True,
)
def docfetching_proxy_task(
    self: Task,
    index_attempt_id: int,
    cc_pair_id: int,
    search_settings_id: int,
    tenant_id: str,
) -> None:
    """
    This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
    docfetching and docprocessing.

    This task is spawned by "try_creating_indexing_task" which is called in the
    "check_for_indexing" task. It spawns a new process for a new scheduled index
    attempt. That new process (which runs the docfetching_task function) does the
    following:

    TODO: clean up
    determines parameters of the indexing attempt (which connector indexing function to run,
    start and end time, from prev checkpoint or not), then run that connector. Specifically,
    connectors are responsible for reading data from an outside source and converting it to Onyx documents.
    At the moment these two steps (reading external data and converting to an Onyx document)
    are not parallelized in most connectors; that's a subject for future work.

    upserts documents to postgres (index_doc_batch_prepare)
    chunks each document (optionally adds context for contextual rag)
    embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
    write chunks to vespa (write_chunks_to_vector_db_with_backoff)
    update document and indexing metadata in postgres
    pulls all document IDs from the source
    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
    from the most recently pulled document ID list

    TODO(rkuo): refactor this so that there is a single return path where we canonically
    log the result of running this function.

    Some more Richard notes:
    celery out of process task execution strategy is pool=prefork, but it uses fork,
    and forking is inherently unstable.
    To work around this, we use pool=threads and proxy our work to a spawned task.

    acks_late must be set to False. Otherwise, celery's visibility timeout will
    cause any task that runs longer than the timeout to be redispatched by the broker.
    There appears to be no good workaround for this, so we need to handle redispatching
    manually.

    NOTE: we try/except all db access in this function because as a watchdog, this function
    needs to be extremely stable.
    """
    start = time.monotonic()

    result = SimpleJobResult()

    ctx = ConnectorIndexingContext(
        tenant_id=tenant_id,
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        index_attempt_id=index_attempt_id,
    )

    log_builder = ConnectorIndexingLogBuilder(ctx)

    task_logger.info(
        log_builder.build(
            "Indexing watchdog - starting",
            mp_start_method=str(multiprocessing.get_start_method()),
        )
    )

    if not self.request.id:
        task_logger.error("self.request.id is None!")

    client = SimpleJobClient()
    task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")

    # Spawn the real work in a separate process so that a crash there cannot
    # take down this watchdog.
    job = client.submit(
        docfetching_task,
        self.app,
        index_attempt_id,
        cc_pair_id,
        search_settings_id,
        global_version.is_ee_version(),
        tenant_id,
    )

    if not job or not job.process:
        result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
        task_logger.info(
            log_builder.build(
                "Indexing watchdog - finished",
                status=str(result.status.value),
                exit_code=str(result.exit_code),
            )
        )
        return

    # Ensure the process has moved out of the starting state
    # (give it up to ~15s; otherwise report SPAWN_NOT_ALIVE and release the job)
    num_waits = 0
    while True:
        if num_waits > 15:
            result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
            task_logger.info(
                log_builder.build(
                    "Indexing watchdog - finished",
                    status=str(result.status.value),
                    exit_code=str(result.exit_code),
                )
            )
            job.release()
            return

        if job.process.is_alive() or job.process.exitcode is not None:
            break

        sleep(1)
        num_waits += 1

    task_logger.info(
        log_builder.build(
            "Indexing watchdog - spawn succeeded",
            pid=str(job.process.pid),
        )
    )

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # Track the last time memory info was emitted
    last_memory_emit_time = 0.0

    # track the last ttl and the time it was observed
    last_activity_ttl_observed: float = time.monotonic()
    last_activity_ttl: int = 0

    try:
        with get_session_with_current_tenant() as db_session:
            index_attempt = get_index_attempt(
                db_session=db_session,
                index_attempt_id=index_attempt_id,
                eager_load_cc_pair=True,
            )
            if not index_attempt:
                raise RuntimeError("Index attempt not found")

            result.connector_source = (
                index_attempt.connector_credential_pair.connector.source.value
            )

        redis_connector_index.set_active()  # renew active signal

        # prime the connector active signal (renewed inside the connector)
        redis_connector_index.set_connector_active()

        # Main watchdog loop: poll the spawned process every 5s until it
        # finishes or a stop condition fires.
        while True:
            sleep(5)

            now = time.monotonic()

            # renew watchdog signal (this has a shorter timeout than set_active)
            redis_connector_index.set_watchdog(True)

            # renew active signal
            redis_connector_index.set_active()

            # if the job is done, clean up and break
            if job.done():
                try:
                    result = process_job_result(
                        job, result.connector_source, redis_connector_index, log_builder
                    )
                except Exception:
                    task_logger.exception(
                        log_builder.build(
                            "Indexing watchdog - spawned task exceptioned"
                        )
                    )
                finally:
                    job.release()

                break

            # log the memory usage for tracking down memory leaks / connector-specific memory issues
            pid = job.process.pid
            if pid is not None:
                # Only emit memory info once per minute (60 seconds)
                current_time = time.monotonic()
                if current_time - last_memory_emit_time >= 60.0:
                    emit_process_memory(
                        pid,
                        "indexing_worker",
                        {
                            "cc_pair_id": cc_pair_id,
                            "search_settings_id": search_settings_id,
                            "index_attempt_id": index_attempt_id,
                        },
                    )
                    last_memory_emit_time = current_time

            # if a termination signal is detected, break (exit point will clean up)
            if self.request.id and redis_connector_index.terminating(self.request.id):
                task_logger.warning(
                    log_builder.build("Indexing watchdog - termination signal detected")
                )

                result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
                break

            # if activity timeout is detected, break (exit point will clean up)
            ttl = redis_connector_index.connector_active_ttl()
            if ttl < 0:
                # verify expectations around ttl
                # NOTE(review): this computes past-minus-now and is therefore
                # always <= 0 in the log output; `now - last_activity_ttl_observed`
                # was probably intended — confirm.
                last_observed = last_activity_ttl_observed - now
                if now > last_activity_ttl_observed + last_activity_ttl:
                    task_logger.warning(
                        log_builder.build(
                            "Indexing watchdog - activity timeout exceeded",
                            last_observed=f"{last_observed:.2f}s",
                            last_ttl=f"{last_activity_ttl}",
                            timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
                        )
                    )

                    result.status = (
                        IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
                    )
                    break
                else:
                    task_logger.warning(
                        log_builder.build(
                            "Indexing watchdog - activity timeout expired unexpectedly, "
                            "waiting for last observed TTL before exiting",
                            last_observed=f"{last_observed:.2f}s",
                            last_ttl=f"{last_activity_ttl}",
                            timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
                        )
                    )
            else:
                last_activity_ttl_observed = now
                last_activity_ttl = ttl

            # if the spawned task is still running, restart the check once again
            # if the index attempt is not in a finished status
            # NOTE(review): when the attempt IS finished, control falls through
            # to the end of the loop body, which re-loops just like `continue`
            # does — a `break` may be missing here. Confirm intended behavior.
            try:
                with get_session_with_current_tenant() as db_session:
                    index_attempt = get_index_attempt(
                        db_session=db_session, index_attempt_id=index_attempt_id
                    )

                    if not index_attempt:
                        continue

                    if not index_attempt.is_finished():
                        continue
            except Exception:
                task_logger.exception(
                    log_builder.build(
                        "Indexing watchdog - transient exception looking up index attempt"
                    )
                )
                continue

    except Exception as e:
        result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
        if isinstance(e, ConnectorValidationError):
            # No need to expose full stack trace for validation errors
            result.exception_str = str(e)
        else:
            result.exception_str = traceback.format_exc()

    # handle exit and reporting
    elapsed = time.monotonic() - start

    if result.exception_str is not None:
        # print with exception
        try:
            with get_session_with_current_tenant() as db_session:
                failure_reason = (
                    f"Spawned task exceptioned: exit_code={result.exit_code}"
                )
                mark_attempt_failed(
                    ctx.index_attempt_id,
                    db_session,
                    failure_reason=failure_reason,
                    full_exception_trace=result.exception_str,
                )
        except Exception:
            task_logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as failed"
                )
            )

        # escape newlines and quotes so the trace fits on one log line
        normalized_exception_str = "None"
        if result.exception_str:
            normalized_exception_str = result.exception_str.replace(
                "\n", "\\n"
            ).replace('"', '\\"')

        task_logger.warning(
            log_builder.build(
                "Indexing watchdog - finished",
                source=result.connector_source,
                status=result.status.value,
                exit_code=str(result.exit_code),
                exception=f'"{normalized_exception_str}"',
                elapsed=f"{elapsed:.2f}s",
            )
        )

        redis_connector_index.set_watchdog(False)
        raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")

    # print without exception
    if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
        try:
            with get_session_with_current_tenant() as db_session:
                # NOTE(review): logger.exception outside an except block logs
                # "NoneType: None" as the traceback; logger.warning/info was
                # probably intended — confirm.
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to termination signal"
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session,
                    "Connector termination signal detected",
                )
        except Exception:
            task_logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as canceled"
                )
            )

        job.cancel()
    elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
        try:
            with get_session_with_current_tenant() as db_session:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session,
                    "Indexing watchdog - activity timeout exceeded: "
                    f"attempt={index_attempt_id} "
                    f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
                )
        except Exception:
            logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as failed"
                )
            )
        job.cancel()
    else:
        pass

    task_logger.info(
        log_builder.build(
            "Indexing watchdog - finished",
            source=result.connector_source,
            status=str(result.status.value),
            exit_code=str(result.exit_code),
            elapsed=f"{elapsed:.2f}s",
        )
    )

    redis_connector_index.set_watchdog(False)

File diff suppressed because it is too large Load Diff

View File

@@ -123,10 +123,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
return bool(self.redis_connector.stop.fenced)
def progress(self, tag: str, amount: int) -> None:
"""Amount isn't used yet."""
@@ -268,7 +265,7 @@ def validate_indexing_fence(
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
if found:
# the celery task exists in the redis queue
@@ -327,7 +324,7 @@ def validate_indexing_fences(
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
# Use replica for this because the worst thing that happens
@@ -414,10 +411,12 @@ def should_index(
)
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
task_logger.info(
f"_should_index: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector_id} "
f"refresh_freq={connector.refresh_freq}"
)
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -517,7 +516,7 @@ def should_index(
return True
def try_creating_indexing_task(
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
@@ -592,16 +591,18 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_INDEXING
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -613,7 +614,7 @@ def try_creating_indexing_task(
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()

View File

@@ -0,0 +1,110 @@
from enum import Enum
from pydantic import BaseModel
class ConnectorIndexingContext(BaseModel):
    """Identifies a single indexing run.

    Bundles the tenant, connector/credential pair, search settings, and the
    index attempt row so the run can be referenced as one unit (e.g. when
    passing task metadata between workers).
    """

    tenant_id: str
    cc_pair_id: int
    search_settings_id: int
    index_attempt_id: int
class IndexingWatchdogTerminalStatus(str, Enum):
    """The different statuses the watchdog can finish with.

    TODO: create broader success/failure/abort categories
    """

    UNDEFINED = "undefined"

    SUCCEEDED = "succeeded"

    SPAWN_FAILED = "spawn_failed"  # connector spawn failed
    SPAWN_NOT_ALIVE = (
        "spawn_not_alive"  # spawn succeeded but process did not come alive
    )

    BLOCKED_BY_DELETION = "blocked_by_deletion"
    BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
    FENCE_NOT_FOUND = "fence_not_found"  # fence does not exist
    FENCE_READINESS_TIMEOUT = (
        "fence_readiness_timeout"  # fence exists but wasn't ready within the timeout
    )
    FENCE_MISMATCH = "fence_mismatch"  # task and fence metadata mismatch
    TASK_ALREADY_RUNNING = "task_already_running"  # task appears to be running already
    INDEX_ATTEMPT_MISMATCH = (
        "index_attempt_mismatch"  # expected index attempt metadata not found in db
    )

    CONNECTOR_VALIDATION_ERROR = (
        "connector_validation_error"  # the connector validation failed
    )
    CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

    # the watchdog received a termination signal
    TERMINATED_BY_SIGNAL = "terminated_by_signal"

    # the watchdog terminated the task due to no activity
    TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"

    # NOTE: this may actually be the same as SIGKILL, but parsed differently by python
    # consolidate once we know more
    OUT_OF_MEMORY = "out_of_memory"

    PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"

    @property
    def code(self) -> int:
        """Process exit code used to represent this status.

        Raises:
            KeyError: for statuses that have no dedicated exit code
                (e.g. UNDEFINED, SUCCEEDED).
        """
        _ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
            IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
            IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
            IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
            IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
            IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
            IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
            IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
            IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
            IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
        }
        return _ENUM_TO_CODE[self]

    @classmethod
    def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
        """Inverse of `code`.

        The reverse lookup is derived from the `code` property rather than a
        second hand-maintained table, so the two mappings cannot drift apart.
        Unknown codes map to UNDEFINED.
        """
        for status in cls:
            try:
                if status.code == code:
                    return status
            except KeyError:
                # status has no exit code assigned; skip it
                continue
        return cls.UNDEFINED
class SimpleJobResult:
    """The data we want to have when the watchdog finishes.

    Plain mutable record: callers construct it empty and fill in fields as
    the watchdog progresses.
    """

    # Attribute declarations up front so the record's shape is visible at a glance.
    status: IndexingWatchdogTerminalStatus
    connector_source: str | None
    exit_code: int | None
    exception_str: str | None

    def __init__(self) -> None:
        # Start from the "nothing known yet" state; every field is
        # overwritten as the watchdog learns the outcome.
        self.status = IndexingWatchdogTerminalStatus.UNDEFINED
        self.connector_source = None
        self.exit_code = None
        self.exception_str = None

View File

@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.DOCPROCESSING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"docfetching={n_docfetching} "
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -441,7 +441,7 @@ def connector_pruning_generator_task(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
    """Return the docfetching worker's Celery app.

    The import is deliberately local so the app module is only loaded when
    this factory is called — i.e. after the module-level
    `set_is_ee_based_on_env_variable()` above has run.
    """
    from onyx.background.celery.apps.docfetching import celery_app

    return celery_app


# Module-level app object that `celery -A <this module>` discovers.
app = get_app()

View File

@@ -33,7 +33,7 @@ def save_checkpoint(
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.save_file(
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(
def load_checkpoint(
db_session: Session, index_attempt_id: int, connector: BaseConnector
index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointedConnector):
@@ -144,7 +144,6 @@ def get_latest_valid_checkpoint(
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None

View File

@@ -5,7 +5,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from pydantic import BaseModel
from celery import Celery
from sqlalchemy.orm import Session
from onyx.access.access import source_should_fetch_permissions_during_indexing
@@ -19,17 +19,22 @@ from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import DocIndexingContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -52,6 +57,7 @@ from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -59,6 +65,7 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.middleware import make_randomized_onyx_request_id
@@ -68,6 +75,7 @@ from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
@@ -184,21 +192,8 @@ class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
db_session_temp: Session, ctx: DocExtractionContext, index_attempt_id: int
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
@@ -227,6 +222,7 @@ def _check_connector_and_attempt_status(
)
# TODO: delete from here if ends up unused
def _check_failure_threshold(
total_failures: int,
document_count: int,
@@ -271,7 +267,12 @@ def _run_indexing(
start_time = time.monotonic() # jsut used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
@@ -292,7 +293,7 @@ def _run_indexing(
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = RunIndexingContext(
ctx = DocExtractionContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
@@ -317,6 +318,7 @@ def _run_indexing(
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
@@ -416,7 +418,9 @@ def _run_indexing(
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -456,6 +460,7 @@ def _run_indexing(
checkpoint=checkpoint,
)
# TODO: ensure this logic is moved correctly
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
@@ -815,6 +820,7 @@ def _run_indexing(
def run_indexing_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
connector_credential_pair_id: int,
@@ -852,8 +858,14 @@ def run_indexing_entrypoint(
f"credentials='{credential_id}'"
)
with get_session_with_current_tenant() as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
connector_document_extraction(
app,
index_attempt_id,
attempt.connector_credential_pair_id,
attempt.search_settings_id,
tenant_id,
callback,
)
logger.info(
f"Indexing finished{tenant_str}: "
@@ -861,3 +873,557 @@ def run_indexing_entrypoint(
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
def connector_document_extraction(
    app: Celery,
    index_attempt_id: int,
    cc_pair_id: int,
    search_settings_id: int,
    tenant_id: str,
    callback: IndexingHeartbeatInterface | None = None,
) -> None:
    """Extract documents from connector and queue them for indexing pipeline processing.

    This is the first part of the split indexing process that runs the connector
    and extracts documents, storing them in local files or S3 for later processing.

    Each extracted document batch is stored via the batch storage layer and a
    DOCPROCESSING task is enqueued per batch; once the connector is exhausted,
    `check_indexing_completion` blocks until all batches are processed.

    Args:
        app: Celery app used to enqueue the per-batch processing tasks.
        index_attempt_id: IndexAttempt row driving this run.
        cc_pair_id: connector/credential pair being indexed.
        search_settings_id: search settings id (used for logging here; the
            attempt row carries the authoritative settings object).
        tenant_id: tenant the run belongs to.
        callback: optional heartbeat/stop-signal hook, polled every iteration.

    Raises:
        Exception: re-raised after the attempt is marked failed/canceled,
            except for ConnectorStopSignal which is swallowed after the
            attempt is canceled.
    """
    start_time = time.monotonic()

    logger.info(
        f"Document extraction starting: "
        f"attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"tenant={tenant_id}"
    )

    # Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
    batch_storage = get_document_batch_storage(tenant_id, index_attempt_id)

    # Initialize memory tracer. NOTE: won't actually do anything if
    # `INDEXING_TRACER_INTERVAL` is 0.
    # NOTE(review): `memory_tracer.stop()` is only called on the exception
    # paths below — the success path leaves the tracer running. Confirm
    # whether that is intentional.
    memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
    memory_tracer.start()

    index_attempt = None
    # comes from _run_indexing
    with get_session_with_current_tenant() as db_session:
        index_attempt = get_index_attempt(
            db_session,
            index_attempt_id,
            eager_load_cc_pair=True,
            eager_load_search_settings=True,
        )
        if not index_attempt:
            raise RuntimeError(f"Index attempt {index_attempt_id} not found")
        if index_attempt.search_settings is None:
            raise ValueError("Search settings must be set for indexing")

        # Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
        if index_attempt.connector_credential_pair.indexing_trigger is not None:
            logger.info(
                "Clearing indexing trigger: "
                f"cc_pair={index_attempt.connector_credential_pair.id} "
                f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
            )
            mark_ccpair_with_indexing_trigger(
                index_attempt.connector_credential_pair.id, None, db_session
            )

        db_connector = index_attempt.connector_credential_pair.connector
        db_credential = index_attempt.connector_credential_pair.credential
        # primary == the search settings currently serving queries
        is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
        from_beginning = index_attempt.from_beginning
        has_successful_attempt = (
            index_attempt.connector_credential_pair.last_successful_index_time
            is not None
        )
        ctx = DocExtractionContext(
            index_name=index_attempt.search_settings.index_name,
            cc_pair_id=cc_pair_id,
            connector_id=db_connector.id,
            credential_id=db_credential.id,
            source=db_connector.source,
            earliest_index_time=(
                db_connector.indexing_start.timestamp()
                if db_connector.indexing_start
                else 0
            ),
            from_beginning=index_attempt.from_beginning,
            # Only update cc-pair status for primary index jobs
            # Secondary index syncs at the end when swapping
            is_primary=is_primary,
            search_settings_status=index_attempt.search_settings.status,
            doc_extraction_complete_batch_num=None,  # None means not completed yet
            should_fetch_permissions_during_indexing=(
                index_attempt.connector_credential_pair.access_type == AccessType.SYNC
                and source_should_fetch_permissions_during_indexing(db_connector.source)
                and is_primary
                # if we've already successfully indexed, let the doc_sync job
                # take care of doc-level permissions
                and (from_beginning or not has_successful_attempt)
            ),
        )

        # Set up time windows for polling
        last_successful_index_poll_range_end = (
            ctx.earliest_index_time
            if ctx.from_beginning
            else get_last_successful_attempt_poll_range_end(
                cc_pair_id=ctx.cc_pair_id,
                earliest_index=ctx.earliest_index_time,
                search_settings=index_attempt.search_settings,
                db_session=db_session,
            )
        )

        if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
            window_start = datetime.fromtimestamp(
                last_successful_index_poll_range_end, tz=timezone.utc
            ) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
        else:
            # don't go into "negative" time if we've never indexed before
            window_start = datetime.fromtimestamp(0, tz=timezone.utc)

        most_recent_attempt = next(
            iter(
                get_recent_completed_attempts_for_cc_pair(
                    cc_pair_id=ctx.cc_pair_id,
                    search_settings_id=index_attempt.search_settings_id,
                    db_session=db_session,
                    limit=1,
                )
            ),
            None,
        )

        # if the last attempt failed, try and use the same window. This is necessary
        # to ensure correctness with checkpointing. If we don't do this, things like
        # new slack channels could be missed (since existing slack channels are
        # cached as part of the checkpoint).
        if (
            most_recent_attempt
            and most_recent_attempt.poll_range_end
            and (
                most_recent_attempt.status == IndexingStatus.FAILED
                or most_recent_attempt.status == IndexingStatus.CANCELED
            )
        ):
            window_end = most_recent_attempt.poll_range_end
        else:
            window_end = datetime.now(tz=timezone.utc)

        batch_storage.store_extraction_state(ctx)

        # initialize indexing state
        current_state = DocIndexingContext(
            batches_done=0,
            total_failures=0,
            net_doc_change=0,
            total_chunks=0,
        )
        batch_storage.store_indexing_state(current_state)

        # TODO: maybe memory tracer here

        # Set up connector runner
        connector_runner = _get_connector_runner(
            db_session=db_session,
            attempt=index_attempt,
            batch_size=INDEX_BATCH_SIZE,
            start_time=window_start,
            end_time=window_end,
            include_permissions=ctx.should_fetch_permissions_during_indexing,
        )

        # don't use a checkpoint if we're explicitly indexing from
        # the beginning in order to avoid weird interactions between
        # checkpointing / failure handling
        # OR
        # if the last attempt was successful
        if index_attempt.from_beginning or (
            most_recent_attempt and most_recent_attempt.status.is_successful()
        ):
            checkpoint = connector_runner.connector.build_dummy_checkpoint()
        else:
            checkpoint = get_latest_valid_checkpoint(
                db_session=db_session,
                cc_pair_id=ctx.cc_pair_id,
                search_settings_id=index_attempt.search_settings_id,
                window_start=window_start,
                window_end=window_end,
                connector=connector_runner.connector,
            )

        # Save initial checkpoint
        save_checkpoint(
            db_session=db_session,
            index_attempt_id=index_attempt_id,
            checkpoint=checkpoint,
        )

    try:
        batch_num = 0
        total_doc_batches_queued = 0
        total_failures = 0
        # NOTE(review): document_count is initialized but never incremented in
        # this function, so the failure-threshold check always sees 0 docs —
        # confirm whether it should track extracted documents.
        document_count = 0

        # Main extraction loop
        while checkpoint.has_more:
            logger.info(
                f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
            )
            for document_batch, failure, next_checkpoint in connector_runner.run(
                checkpoint
            ):
                # Check if connector is disabled mid run and stop if so unless it's the secondary
                # index being built. We want to populate it even for paused connectors
                # Often paused connectors are sources that aren't updated frequently but the
                # contents still need to be initially pulled.
                if callback:
                    if callback.should_stop():
                        raise ConnectorStopSignal("Connector stop signal detected")

                    # NOTE: this progress callback runs on every loop. We've seen cases
                    # where we loop many times with no new documents and eventually time
                    # out, so only doing the callback after indexing isn't sufficient.
                    # TODO: change to doc extraction if it doesnt break things
                    callback.progress("_run_indexing", 0)

                # will exception if the connector/index attempt is marked as paused/failed
                with get_session_with_current_tenant() as db_session:
                    _check_connector_and_attempt_status(
                        db_session, ctx, index_attempt_id
                    )

                # save record of any failures at the connector level
                if failure is not None:
                    total_failures += 1
                    with get_session_with_current_tenant() as db_session:
                        create_index_attempt_error(
                            index_attempt_id,
                            ctx.cc_pair_id,
                            failure,
                            db_session,
                        )
                    _check_failure_threshold(
                        total_failures, document_count, batch_num, failure
                    )

                # Save checkpoint if provided
                if next_checkpoint:
                    checkpoint = next_checkpoint

                # below is all document processing task, so if no batch we can just continue
                if document_batch is None:
                    continue

                # Clean documents and create batch
                doc_batch_cleaned = strip_null_characters(document_batch)
                batch_description = []
                for doc in doc_batch_cleaned:
                    batch_description.append(doc.to_short_descriptor())

                    # warn on unusually large documents (total text size)
                    doc_size = 0
                    for section in doc.sections:
                        if (
                            isinstance(section, TextSection)
                            and section.text is not None
                        ):
                            doc_size += len(section.text)

                    if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
                        logger.warning(
                            f"Document size: doc='{doc.to_short_descriptor()}' "
                            f"size={doc_size} "
                            f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
                        )

                logger.debug(f"Indexing batch of documents: {batch_description}")
                memory_tracer.increment_and_maybe_trace()

                # TODO: replicate index_attempt_md from _run_indexing, store in blob storage
                # instead of that big extraction state. index_attempt_md will be the
                # persisted communication between the connector extraction task and the
                # document processing task
                batch_id = f"batch_{batch_num}"
                # Store documents in storage
                batch_storage.store_batch(batch_id, doc_batch_cleaned)
                batch_storage.store_extraction_state(ctx)

                # Create processing task data
                processing_batch_data = {
                    "batch_id": batch_id,
                    "index_attempt_id": index_attempt_id,
                    "cc_pair_id": cc_pair_id,
                    "tenant_id": tenant_id,
                    "batch_num": batch_num,  # 0-indexed
                }

                # Queue document processing task
                app.send_task(
                    OnyxCeleryTask.DOCPROCESSING_TASK,
                    kwargs=processing_batch_data,
                    queue=OnyxCeleryQueues.DOCPROCESSING,
                    priority=OnyxCeleryPriority.MEDIUM,
                )

                batch_num += 1
                total_doc_batches_queued += 1
                logger.info(
                    f"Queued document processing batch: "
                    f"batch_id={batch_id} "
                    f"docs={len(doc_batch_cleaned)} "
                    f"attempt={index_attempt_id}"
                )

            # Check checkpoint size periodically
            CHECKPOINT_SIZE_CHECK_INTERVAL = 100
            if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
                check_checkpoint_size(checkpoint)

            # Save latest checkpoint
            # NOTE: checkpointing is used to track which batches have
            # been stored, NOT which batches have been fully indexed
            # as it used to be.
            with get_session_with_current_tenant() as db_session:
                save_checkpoint(
                    db_session=db_session,
                    index_attempt_id=index_attempt_id,
                    checkpoint=checkpoint,
                )

        elapsed_time = time.monotonic() - start_time
        logger.info(
            f"Document extraction completed: "
            f"attempt={index_attempt_id} "
            f"batches_queued={total_doc_batches_queued} "
            f"elapsed={elapsed_time:.2f}s"
        )

        final_state = batch_storage.get_extraction_state()
        if final_state is None:
            raise RuntimeError("Extraction state should not be None")
        # counts how many batches were queued
        final_state.doc_extraction_complete_batch_num = batch_num
        batch_storage.store_extraction_state(final_state)

        # block until downstream processing of all queued batches finishes
        check_indexing_completion(index_attempt_id, tenant_id)

    except Exception as e:
        logger.exception(
            f"Document extraction failed: "
            f"attempt={index_attempt_id} "
            f"error={str(e)}"
        )

        # Clean up on failure
        try:
            batch_storage.cleanup_all_batches()
        except Exception:
            logger.exception(
                "Failed to clean up document batches after extraction failure"
            )

        if isinstance(e, ConnectorValidationError):
            # On validation errors during indexing, we want to cancel the indexing attempt
            # and mark the CCPair as invalid. This prevents the connector from being
            # used in the future until the credentials are updated.
            with get_session_with_current_tenant() as db_session_temp:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to validation error."
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
                )

                if is_primary:
                    if not index_attempt:
                        # should always be set by now
                        raise RuntimeError("Should never happen.")

                    VALIDATION_ERROR_THRESHOLD = 5

                    recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
                        cc_pair_id=cc_pair_id,
                        search_settings_id=index_attempt.search_settings_id,
                        limit=VALIDATION_ERROR_THRESHOLD,
                        db_session=db_session_temp,
                    )
                    num_validation_errors = len(
                        [
                            index_attempt
                            for index_attempt in recent_index_attempts
                            if index_attempt.error_msg
                            and index_attempt.error_msg.startswith(
                                CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
                            )
                        ]
                    )

                    if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
                        logger.warning(
                            f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
                            f" errors. Marking the CC Pair as invalid."
                        )
                        update_connector_credential_pair(
                            db_session=db_session_temp,
                            connector_id=ctx.connector_id,
                            credential_id=ctx.credential_id,
                            status=ConnectorCredentialPairStatus.INVALID,
                        )
            memory_tracer.stop()
            raise e
        elif isinstance(e, ConnectorStopSignal):
            # cancellation requested: cancel the attempt but do not re-raise
            with get_session_with_current_tenant() as db_session_temp:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to stop signal."
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=str(e),
                )
                if is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                    )
            memory_tracer.stop()
        else:
            # any other failure: mark the attempt failed and re-raise
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session_temp,
                    failure_reason=str(e),
                    full_exception_trace=traceback.format_exc(),
                )
                if is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                    )
            memory_tracer.stop()
            raise e
def check_indexing_completion(
    index_attempt_id: int,
    tenant_id: str,
) -> None:
    """Block until every queued document batch for this attempt is processed,
    then finalize the attempt.

    Polls the batch storage state until `batches_done` reaches the total batch
    count recorded by extraction, marks the attempt succeeded (or partially
    succeeded when failures were recorded), cleans up batch storage, and resets
    the Redis indexing fence. If monitoring itself fails (including the
    6-hour no-progress timeout), the attempt is marked failed and storage is
    cleaned up best-effort.

    NOTE(review): the polling loop below has no sleep between iterations, so
    it spins and logs continuously until the counts converge — confirm whether
    a delay between polls was intended.
    """
    logger.info(
        f"Checking for document processing completion: "
        f"attempt={index_attempt_id} "
        f"tenant={tenant_id}"
    )

    storage = get_document_batch_storage(tenant_id, index_attempt_id)

    try:
        last_progress_time = time.monotonic()
        last_batches_completed = 0
        # TODO: if there are no indexing workers running, fail the attempt
        while True:
            # Get current state
            indexing_state = storage.ensure_indexing_state()
            extraction_state = storage.ensure_extraction_state()

            # Check if extraction is complete and all batches are processed
            batches_total = extraction_state.doc_extraction_complete_batch_num
            batches_processed = indexing_state.batches_done
            extraction_completed = batches_total is not None and (
                batches_processed >= batches_total
            )
            logger.info(
                f"Processing status: "
                f"extraction_completed={extraction_completed} "
                f"batches_processed={batches_processed}/{batches_total or '?'}"  # probably off by 1
            )
            if extraction_completed:
                break

            # fail the attempt if no batch completes for 6 hours straight
            if batches_processed > last_batches_completed:
                last_batches_completed = batches_processed
                last_progress_time = time.monotonic()
            elif time.monotonic() - last_progress_time > 3600 * 6:
                raise RuntimeError(
                    f"Indexing attempt {index_attempt_id} has been indexing for 6 hours without progress. "
                    f"Marking it as failed."
                )

        logger.info(
            f"All batches for index attempt {index_attempt_id} have been processed."
        )

        # All processing is complete
        total_failures = indexing_state.total_failures
        with get_session_with_current_tenant() as db_session:
            if total_failures == 0:
                mark_attempt_succeeded(index_attempt_id, db_session)
                logger.info(f"Index attempt {index_attempt_id} completed successfully")
            else:
                mark_attempt_partially_succeeded(index_attempt_id, db_session)
                logger.info(
                    f"Index attempt {index_attempt_id} completed with {total_failures} failures"
                )

        # Clean up all remaining storage
        storage.cleanup_all_batches()

        # Clean up the fence to allow the next run to start
        with get_session_with_current_tenant() as db_session:
            index_attempt = get_index_attempt(db_session, index_attempt_id)
            if index_attempt and index_attempt.search_settings:
                redis_connector = RedisConnector(
                    tenant_id, index_attempt.connector_credential_pair_id
                )
                redis_connector_index = redis_connector.new_index(
                    index_attempt.search_settings.id
                )
                redis_connector_index.reset()
                logger.info(f"Resetting indexing fence for attempt {index_attempt_id}")

    except Exception as e:
        logger.exception(
            f"Failed to monitor document processing completion: "
            f"attempt={index_attempt_id} "
            f"error={str(e)}"
        )

        # Mark the attempt as failed if monitoring fails
        try:
            with get_session_with_current_tenant() as db_session:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session,
                    failure_reason=f"Processing monitoring failed: {str(e)}",
                    full_exception_trace=traceback.format_exc(),
                )
        except Exception:
            logger.exception("Failed to mark attempt as failed")

        # Try to clean up storage
        try:
            storage.cleanup_all_batches()
        except Exception:
            logger.exception("Failed to cleanup storage after monitoring failure")

View File

@@ -725,9 +725,7 @@ def stream_chat_message_objects(
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []

View File

@@ -324,6 +324,19 @@ except ValueError:
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
# Concurrency for doc-fetching workers. CELERY_WORKER_DOCFETCHING_CONCURRENCY
# takes precedence, then NUM_DOCFETCHING_WORKERS, then the default above.
# Empty strings are treated the same as unset.
try:
    _docfetching_concurrency_raw = (
        os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
        or os.environ.get("NUM_DOCFETCHING_WORKERS")
        or str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
    )
    CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(_docfetching_concurrency_raw)
except ValueError:
    # Non-numeric value in the environment: fall back to the default.
    CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
        CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
    )
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 1024

View File

@@ -66,6 +66,7 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
@@ -320,9 +321,12 @@ class OnyxCeleryQueues:
CSV_GENERATION = "csv_generation"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
# Monitoring queue
MONITORING = "monitoring"
@@ -453,7 +457,11 @@ class OnyxCeleryTask:
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
# New split indexing tasks
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
DOCPROCESSING_TASK = "docprocessing_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"

View File

@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# TODO: Refactor to avoid direct DB access in connector
# This will require broader refactoring across the codebase
with get_session_with_current_tenant() as db_session:
image_section, _ = store_image_and_create_section(
db_session=db_session,
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
image_section, _ = store_image_and_create_section(
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing image {key}")
continue

View File

@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
@@ -224,19 +223,17 @@ def _process_image_attachment(
"""Process an image attachment by saving it without generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
section, file_name = store_image_and_create_section(
db_session=db_session,
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)

View File

@@ -5,8 +5,6 @@ from pathlib import Path
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -32,7 +29,6 @@ logger = setup_logger()
def _create_image_section(
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
@@ -58,7 +54,6 @@ def _create_image_section(
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_id=file_id,
display_name=display_name,
@@ -77,7 +72,6 @@ def _process_file(
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
) -> list[Document]:
"""
Process a file and return a list of Documents.
@@ -125,7 +119,6 @@ def _process_file(
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=file_id,
display_name=title,
)
@@ -196,7 +189,6 @@ def _process_file(
try:
image_section, stored_file_name = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
idx=idx,
@@ -260,37 +252,33 @@ class LocalFileConnector(LoadConnector):
"""
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
for file_id in self.file_locations:
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(
f"No file record found for '{file_id}' in PG; skipping."
)
continue
for file_id in self.file_locations:
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
)
documents.extend(new_docs)
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":
connector = LocalFileConnector(

View File

@@ -1,6 +1,10 @@
import copy
import json
import os
import sys
import threading
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
@@ -1374,3 +1378,139 @@ class GoogleDriveConnector(
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
    """Parse and validate a serialized checkpoint JSON string.

    Raises pydantic's ValidationError if the payload does not match the
    GoogleDriveCheckpoint schema.
    """
    return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
def get_credentials_from_env(email: str, oauth: bool) -> dict:
    """Build a connector credential dict from environment variables.

    Reads the OAuth token JSON from GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR
    when `oauth` is True, otherwise the service account key JSON from
    GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR, and packages it together with the
    primary admin email under the keys the credential store expects.
    """
    env_var_name = (
        "GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"
        if oauth
        else "GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"
    )
    # Round-trip through json to normalize the credential string's formatting.
    refried_credential_string = json.dumps(json.loads(os.environ[env_var_name]))

    # This is the Oauth token
    DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
    # This is the service account key
    DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
    # The email saved for both auth types
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
    DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"

    credential_payload_key = (
        DB_CREDENTIALS_DICT_TOKEN_KEY
        if oauth
        else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
    )
    return {
        credential_payload_key: refried_credential_string,
        DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
        DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
    }
class CheckpointOutputWrapper:
    """
    Wraps a CheckpointOutput generator to give things back in a more digestible format.
    The connector format is easier for the connector implementor (e.g. it enforces exactly
    one new checkpoint is returned AND that the checkpoint is at the end), thus the different
    formats.
    """

    def __init__(self) -> None:
        # Populated by _inner_wrapper with the generator's return value once
        # the wrapped connector generator is exhausted.
        self.next_checkpoint: GoogleDriveCheckpoint | None = None

    def __call__(
        self,
        checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
    ) -> Generator[
        tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
        None,
        None,
    ]:
        # grabs the final return value and stores it in the `next_checkpoint` variable
        def _inner_wrapper(
            checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
        ) -> CheckpointOutput[GoogleDriveCheckpoint]:
            # `yield from` forwards every yielded item and captures the
            # generator's `return` value (the final checkpoint).
            self.next_checkpoint = yield from checkpoint_connector_generator
            return self.next_checkpoint  # not used

        for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
            # Exactly one slot of the 3-tuple is populated per yield.
            if isinstance(document_or_failure, Document):
                yield document_or_failure, None, None
            elif isinstance(document_or_failure, ConnectorFailure):
                yield None, document_or_failure, None
            else:
                raise ValueError(
                    f"Invalid document_or_failure type: {type(document_or_failure)}"
                )

        if self.next_checkpoint is None:
            raise RuntimeError(
                "Checkpoint is None. This should never happen - the connector should always return a checkpoint."
            )

        # The checkpoint is always emitted last, after all documents/failures.
        yield None, None, self.next_checkpoint
def yield_all_docs_from_checkpoint_connector(
    connector: GoogleDriveConnector,
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> Iterator[Document | ConnectorFailure]:
    """Drain a checkpointed connector, yielding every document and failure.

    Repeatedly calls load_from_checkpoint with the latest checkpoint until
    the checkpoint reports no more work, with a hard iteration cap as an
    infinite-loop guard.
    """
    checkpoint = connector.build_dummy_checkpoint()
    iterations = 0
    while checkpoint.has_more:
        wrapped_output = CheckpointOutputWrapper()(
            connector.load_from_checkpoint(start, end, checkpoint)
        )
        for document, failure, next_checkpoint in wrapped_output:
            if failure is not None:
                yield failure
            if document is not None:
                yield document
            if next_checkpoint is not None:
                # Adopt the new checkpoint for the next round of fetching.
                checkpoint = next_checkpoint
        iterations += 1
        if iterations > 100_000:
            raise RuntimeError("Too many iterations. Infinite loop?")
if __name__ == "__main__":
    # Manual smoke-test harness: drains a real Google Drive via the connector
    # and writes periodic throughput stats to stats.txt.
    import time

    # Service-account auth (oauth=False); requires the corresponding env vars.
    creds = get_credentials_from_env(
        os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
    )
    connector = GoogleDriveConnector(
        include_shared_drives=True,
        shared_drive_urls=None,
        include_my_drives=True,
        my_drive_emails=None,
        shared_folder_urls=None,
        include_files_shared_with_me=True,
        specific_user_emails=None,
    )
    connector.load_credentials(creds)
    max_fsize = 0  # max doc size seen within the current 200-doc window
    biggest_fsize = 0  # max doc size seen across all windows
    num_errors = 0
    start_time = time.time()
    with open("stats.txt", "w") as f:
        for num, doc_or_failure in enumerate(
            yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
        ):
            # Every 200 docs, flush a progress snapshot and reset the window.
            if num % 200 == 0:
                f.write(f"Processed {num} files\n")
                f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
                f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
                f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
                biggest_fsize = max(biggest_fsize, max_fsize)
                max_fsize = 0
            if isinstance(doc_or_failure, Document):
                # NOTE(review): sys.getsizeof is shallow — it does not include
                # the sections' contents, so "file size" here is approximate.
                max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
            elif isinstance(doc_or_failure, ConnectorFailure):
                num_errors += 1
    print(f"Num errors: {num_errors}")
    print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
    print(f"Time taken: {time.time() - start_time:.2f} seconds")

View File

@@ -29,7 +29,6 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
@@ -193,17 +192,15 @@ def _download_and_extract_sections_basic(
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
@@ -216,16 +213,14 @@ def _download_and_extract_sections_basic(
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections

View File

@@ -12,7 +12,6 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -68,10 +67,7 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
# load the HTML files
files = load_files_from_zip(file_content_io)

View File

@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible
@@ -373,3 +374,24 @@ class OnyxMetadata(BaseModel):
secondary_owners: list[BasicExpertInfo] | None = None
doc_updated_at: datetime | None = None
title: str | None = None
class DocExtractionContext(BaseModel):
    """Static context for the document-extraction phase of an index attempt."""

    index_name: str
    cc_pair_id: int
    connector_id: int
    credential_id: int
    source: DocumentSource
    # presumably a seconds-since-epoch timestamp — TODO confirm against callers
    earliest_index_time: float
    from_beginning: bool
    is_primary: bool
    should_fetch_permissions_during_indexing: bool
    search_settings_status: IndexModelStatus
    # Final batch count once extraction finishes; None while still in progress
    # (presumed — verify against the writer of this field)
    doc_extraction_complete_batch_num: int | None
class DocIndexingContext(BaseModel):
    """Running counters for the indexing phase of an index attempt."""

    batches_done: int
    total_failures: int
    net_doc_change: int
    total_chunks: int

View File

@@ -25,7 +25,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
@@ -70,18 +69,16 @@ def update_image_sections_with_query(
logger.debug(
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_id}"
)
with get_session_with_current_tenant() as db_session:
file_record = get_default_file_store(db_session).read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(
f"Successfully loaded image data for {chunk.image_file_id}"
)
file_record = get_default_file_store().read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(f"Successfully loaded image data for {chunk.image_file_id}")
messages: list[BaseMessage] = [
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),

View File

@@ -229,7 +229,7 @@ def delete_messages_and_files_from_chat_session(
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_info in files or []:
file_store.delete_file(file_id=file_info.get("id"))

View File

@@ -258,6 +258,7 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
@@ -265,6 +266,8 @@ def get_connector_credential_pair_from_id(
if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
result = db_session.execute(stmt)
return result.scalar_one_or_none()

View File

@@ -304,6 +304,18 @@ def get_session_with_current_tenant() -> Generator[Session, None, None]:
yield session
@contextmanager
def get_session_with_current_tenant_if_none(
    session: Session | None,
) -> Generator[Session, None, None]:
    """Yield the given session, or open one for the current tenant if None.

    Lets callers accept an optional session: an existing session is passed
    through untouched, otherwise a fresh tenant-scoped session is created and
    managed for the duration of the `with` block.
    """
    if session is not None:
        yield session
        return
    with get_session_with_tenant(tenant_id=get_current_tenant_id()) as new_session:
        yield new_session
# Used in multi tenant mode when need to refer to the shared `public` schema
@contextmanager
def get_session_with_shared_schema() -> Generator[Session, None, None]:

View File

@@ -43,6 +43,15 @@ def get_filerecord_by_file_id(
return filestore
def get_filerecord_by_prefix(
    prefix: str,
    db_session: Session,
) -> list[FileRecord]:
    """Return all FileRecords whose file_id starts with `prefix`.

    NOTE(review): `prefix` is interpolated into a SQL LIKE pattern, so any
    `%` or `_` characters in it act as wildcards — confirm callers never pass
    such characters.
    """
    pattern = f"{prefix}%"
    matching = db_session.query(FileRecord).filter(FileRecord.file_id.like(pattern))
    return matching.all()
def delete_filerecord_by_file_id(
file_id: str,
db_session: Session,

View File

@@ -28,6 +28,8 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
@@ -95,9 +97,25 @@ def get_recent_attempts_for_cc_pair(
def get_index_attempt(
db_session: Session, index_attempt_id: int
db_session: Session,
index_attempt_id: int,
eager_load_cc_pair: bool = False,
eager_load_search_settings: bool = False,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.connector
)
)
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.credential
)
)
if eager_load_search_settings:
stmt = stmt.options(joinedload(IndexAttempt.search_settings))
return db_session.scalars(stmt).first()
@@ -373,16 +391,22 @@ def update_docs_indexed(
new_docs_indexed: int,
docs_removed_from_index: int,
) -> None:
"""Updates the docs_indexed and new_docs_indexed fields of an index attempt.
Adds the given values to the current values in the db"""
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
.with_for_update() # Locks the row when we try to update
).scalar_one()
attempt.total_docs_indexed = total_docs_indexed
attempt.new_docs_indexed = new_docs_indexed
attempt.docs_removed_from_index = docs_removed_from_index
attempt.total_docs_indexed = (
attempt.total_docs_indexed or 0
) + total_docs_indexed
attempt.new_docs_indexed = (attempt.new_docs_indexed or 0) + new_docs_indexed
attempt.docs_removed_from_index = (
attempt.docs_removed_from_index or 0
) + docs_removed_from_index
db_session.commit()
except Exception:
db_session.rollback()

View File

@@ -46,7 +46,7 @@ def create_user_files(
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(files, db_session)
upload_response = upload_files(files)
user_files = []
for file_path, file in zip(upload_response.file_paths, files):

View File

@@ -1,8 +1,6 @@
from io import BytesIO
from typing import Tuple
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ImageSection
from onyx.file_store.file_store import get_default_file_store
@@ -12,7 +10,6 @@ logger = setup_logger()
def store_image_and_create_section(
db_session: Session,
image_data: bytes,
file_id: str,
display_name: str,
@@ -24,7 +21,6 @@ def store_image_and_create_section(
Stores an image in FileStore and creates an ImageSection object without summarization.
Args:
db_session: Database session
image_data: Raw image bytes
file_id: Base identifier for the file
display_name: Human-readable name for the image
@@ -38,7 +34,7 @@ def store_image_and_create_section(
"""
# Storage logic
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_id = file_store.save_file(
content=BytesIO(image_data),
display_name=display_name,

View File

@@ -0,0 +1,294 @@
import json
from abc import ABC
from abc import abstractmethod
from enum import Enum
from io import StringIO
from typing import cast
from typing import List
from typing import Optional
from typing import TypeAlias
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import DocIndexingContext
from onyx.connectors.models import Document
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
class DocumentBatchStorageStateType(str, Enum):
    """Names of the two pipeline-state files kept alongside document batches."""

    EXTRACTION = "extraction"
    INDEXING = "indexing"
# Union of the two pipeline-state payload models.
DocumentStorageState: TypeAlias = DocExtractionContext | DocIndexingContext

# Maps a DocumentBatchStorageStateType value to the pydantic model used to
# (de)serialize that state file.
STATE_TYPE_TO_MODEL: dict[str, type[DocumentStorageState]] = {
    DocumentBatchStorageStateType.EXTRACTION.value: DocExtractionContext,
    DocumentBatchStorageStateType.INDEXING.value: DocIndexingContext,
}
class DocumentBatchStorage(ABC):
    """Abstract base class for document batch storage implementations.

    Implementations persist batches of extracted documents plus two pieces of
    per-attempt pipeline state (extraction and indexing), namespaced by
    ``{tenant_id}/{index_attempt_id}``.
    """

    def __init__(self, tenant_id: str, index_attempt_id: int):
        self.tenant_id = tenant_id
        self.index_attempt_id = index_attempt_id
        # Shared namespace prefix for every file belonging to this attempt.
        self.base_path = f"{tenant_id}/{index_attempt_id}"

    # NOTE: annotations below use builtin generics / `| None` for consistency
    # with the rest of this module (was a mix of typing.List/Optional).
    @abstractmethod
    def store_batch(self, batch_id: str, documents: list[Document]) -> None:
        """Store a batch of documents."""

    @abstractmethod
    def get_batch(self, batch_id: str) -> list[Document] | None:
        """Retrieve a batch of documents, or None if it does not exist."""

    @abstractmethod
    def delete_batch(self, batch_id: str) -> None:
        """Delete a specific batch."""

    @abstractmethod
    def store_extraction_state(self, state: DocExtractionContext) -> None:
        """Store extraction state metadata."""

    @abstractmethod
    def get_extraction_state(self) -> DocExtractionContext | None:
        """Get extraction state metadata."""

    @abstractmethod
    def store_indexing_state(self, state: DocIndexingContext) -> None:
        """Store indexing state metadata."""

    @abstractmethod
    def get_indexing_state(self) -> DocIndexingContext | None:
        """Get indexing state metadata."""

    @abstractmethod
    def cleanup_all_batches(self) -> None:
        """Clean up all batches and state for this index attempt."""

    def ensure_extraction_state(self) -> DocExtractionContext:
        """Return the extraction state, raising if it is missing."""
        state = self.get_extraction_state()
        if not state:
            raise RuntimeError("Expected extraction state not found")
        return state

    def ensure_indexing_state(self) -> DocIndexingContext:
        """Return the indexing state, raising if it is missing."""
        state = self.get_indexing_state()
        if not state:
            raise RuntimeError("Expected indexing state not found")
        return state

    def _serialize_documents(self, documents: list[Document]) -> str:
        """Serialize documents to a JSON string."""
        # Use mode='json' to properly serialize datetime and other complex types
        return json.dumps([doc.model_dump(mode="json") for doc in documents], indent=2)

    def _deserialize_documents(self, data: str) -> list[Document]:
        """Deserialize documents from a JSON string."""
        doc_dicts = json.loads(data)
        return [Document.model_validate(doc_dict) for doc_dict in doc_dicts]
class FileStoreDocumentBatchStorage(DocumentBatchStorage):
"""FileStore-based implementation of document batch storage."""
def __init__(self, tenant_id: str, index_attempt_id: int, file_store: FileStore):
super().__init__(tenant_id, index_attempt_id)
self.file_store = file_store
# Track stored batch files for cleanup
self._batch_files: set[str] = set()
def _get_batch_file_name(self, batch_id: str) -> str:
"""Generate file name for a document batch."""
return f"document_batch_{self.base_path.replace('/', '_')}_{batch_id}.json"
def _get_state_file_name(self, state_type: str) -> str:
"""Generate file name for extraction state."""
return f"{state_type}_state_{self.base_path.replace('/', '_')}.json"
def store_batch(self, batch_id: str, documents: list[Document]) -> None:
"""Store a batch of documents using FileStore."""
file_name = self._get_batch_file_name(batch_id)
try:
data = self._serialize_documents(documents)
content = StringIO(data)
self.file_store.save_file(
file_id=file_name,
content=content,
display_name=f"Document Batch {batch_id}",
file_origin=FileOrigin.OTHER,
file_type="application/json",
file_metadata={
"tenant_id": self.tenant_id,
"index_attempt_id": str(self.index_attempt_id),
"batch_id": batch_id,
"document_count": str(len(documents)),
},
)
# Track this batch file for cleanup
self._batch_files.add(file_name)
logger.debug(
f"Stored batch {batch_id} with {len(documents)} documents to FileStore as {file_name}"
)
except Exception as e:
logger.error(f"Failed to store batch {batch_id}: {e}")
raise
def get_batch(self, batch_id: str) -> list[Document] | None:
"""Retrieve a batch of documents from FileStore."""
file_name = self._get_batch_file_name(batch_id)
try:
# Check if file exists
if not self.file_store.has_file(
file_id=file_name,
file_origin=FileOrigin.OTHER,
file_type="application/json",
):
logger.warning(f"Batch {batch_id} not found in FileStore")
return None
content_io = self.file_store.read_file(file_name)
data = content_io.read().decode("utf-8")
documents = self._deserialize_documents(data)
logger.debug(
f"Retrieved batch {batch_id} with {len(documents)} documents from FileStore"
)
return documents
except Exception as e:
logger.error(f"Failed to retrieve batch {batch_id}: {e}")
raise
def delete_batch(self, batch_id: str) -> None:
"""Delete a specific batch from FileStore."""
file_name = self._get_batch_file_name(batch_id)
try:
self.file_store.delete_file(file_name)
# Remove from tracked files
self._batch_files.discard(file_name)
logger.debug(f"Deleted batch {batch_id} from FileStore")
except Exception as e:
logger.warning(f"Failed to delete batch {batch_id}: {e}")
# Don't raise - batch might not exist, which is acceptable
def _store_state(self, state: DocumentStorageState, state_type: str) -> None:
"""Store state using FileStore."""
file_name = self._get_state_file_name(state_type)
try:
data = json.dumps(state.model_dump(mode="json"), indent=2)
content = StringIO(data)
self.file_store.save_file(
file_id=file_name,
content=content,
display_name=f"{state_type.capitalize()} State {self.base_path}",
file_origin=FileOrigin.OTHER,
file_type="application/json",
file_metadata={
"tenant_id": self.tenant_id,
"index_attempt_id": str(self.index_attempt_id),
"type": f"{state_type}_state",
},
)
logger.debug(f"Stored {state_type} state to FileStore as {file_name}")
except Exception as e:
logger.error(f"Failed to store {state_type} state: {e}")
raise
def _get_state(self, state_type: str) -> DocumentStorageState | None:
"""Get state from FileStore."""
file_name = self._get_state_file_name(state_type)
try:
# Check if file exists
if not self.file_store.has_file(
file_id=file_name,
file_origin=FileOrigin.OTHER,
file_type="application/json",
):
return None
content_io = self.file_store.read_file(file_name)
data = content_io.read().decode("utf-8")
state = STATE_TYPE_TO_MODEL[state_type].model_validate(json.loads(data))
logger.debug(f"Retrieved {state_type} state from FileStore")
return state
except Exception as e:
logger.error(f"Failed to retrieve {state_type} state: {e}")
return None
def store_extraction_state(self, state: DocExtractionContext) -> None:
    """Persist the extraction-phase context for this index attempt."""
    state_kind = DocumentBatchStorageStateType.EXTRACTION.value
    self._store_state(state, state_kind)
def get_extraction_state(self) -> DocExtractionContext | None:
    """Load the persisted extraction-phase context, if any."""
    loaded = self._get_state(DocumentBatchStorageStateType.EXTRACTION.value)
    # Narrow the generic state type for the caller.
    return cast(DocExtractionContext | None, loaded)
def store_indexing_state(self, state: DocIndexingContext) -> None:
    """Persist the indexing-phase context for this index attempt."""
    state_kind = DocumentBatchStorageStateType.INDEXING.value
    self._store_state(state, state_kind)
def get_indexing_state(self) -> DocIndexingContext | None:
    """Load the persisted indexing-phase context, if any."""
    loaded = self._get_state(DocumentBatchStorageStateType.INDEXING.value)
    # Narrow the generic state type for the caller.
    return cast(DocIndexingContext | None, loaded)
def cleanup_all_batches(self) -> None:
    """Clean up all batch files and state files for this index attempt.

    Deletes every tracked batch file plus the per-phase (extraction,
    indexing, ...) state files, then clears the in-memory tracking set.
    Individual delete failures are logged and skipped so cleanup is
    best-effort; nothing is raised.

    NOTE: we rely on the tracked-file set rather than listing the backing
    store directly; a more robust cleanup could list/delete from S3 if
    needed, but this is simpler.
    """
    deleted_count = 0
    # Iterate over a snapshot to avoid mutating the set while iterating.
    for file_name in list(self._batch_files):
        try:
            self.file_store.delete_file(file_name)
            deleted_count += 1
        except Exception as e:
            logger.warning(f"Failed to delete batch file {file_name}: {e}")
    # Delete the persisted state file for every state type.
    for state_type in DocumentBatchStorageStateType:
        try:
            state_file_name = self._get_state_file_name(state_type.value)
            self.file_store.delete_file(state_file_name)
            deleted_count += 1
        except Exception as e:
            # Bug fix: previously logged "extraction state" for every
            # state type; report the type actually being deleted.
            logger.warning(f"Failed to delete {state_type.value} state: {e}")
    # Clear tracked files regardless of individual failures above.
    self._batch_files.clear()
    logger.info(
        f"Cleaned up {deleted_count} files for index attempt {self.index_attempt_id}"
    )
def get_document_batch_storage(
    tenant_id: str, index_attempt_id: int
) -> DocumentBatchStorage:
    """Factory for the configured document batch storage implementation.

    The default file store resolves (via environment variables) to
    S3BackedFileStore or another configured backend.
    """
    backing_store = get_default_file_store()
    return FileStoreDocumentBatchStorage(tenant_id, index_attempt_id, backing_store)

View File

@@ -22,10 +22,14 @@ from onyx.configs.app_configs import S3_FILE_STORE_BUCKET_NAME
from onyx.configs.app_configs import S3_FILE_STORE_PREFIX
from onyx.configs.app_configs import S3_VERIFY_SSL
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
from onyx.db.file_record import delete_filerecord_by_file_id
from onyx.db.file_record import get_filerecord_by_file_id
from onyx.db.file_record import get_filerecord_by_file_id_optional
from onyx.db.file_record import get_filerecord_by_prefix
from onyx.db.file_record import upsert_filerecord
from onyx.db.models import FileRecord
from onyx.db.models import FileRecord as FileStoreModel
from onyx.file_store.s3_key_utils import generate_s3_key
from onyx.utils.file import FileWithMimeType
@@ -129,13 +133,18 @@ class FileStore(ABC):
Get the file + parse out the mime type.
"""
@abstractmethod
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
"""
List all file IDs that start with the given prefix.
"""
class S3BackedFileStore(FileStore):
"""Isn't necessarily S3, but is any S3-compatible storage (e.g. MinIO)"""
def __init__(
self,
db_session: Session,
bucket_name: str,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
@@ -144,7 +153,6 @@ class S3BackedFileStore(FileStore):
s3_prefix: str | None = None,
s3_verify_ssl: bool = True,
) -> None:
self.db_session = db_session
self._s3_client: S3Client | None = None
self._bucket_name = bucket_name
self._aws_access_key_id = aws_access_key_id
@@ -272,10 +280,12 @@ class S3BackedFileStore(FileStore):
file_id: str,
file_origin: FileOrigin,
file_type: str,
db_session: Session | None = None,
) -> bool:
file_record = get_filerecord_by_file_id_optional(
file_id=file_id, db_session=self.db_session
)
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id_optional(
file_id=file_id, db_session=db_session
)
return (
file_record is not None
and file_record.file_origin == file_origin
@@ -290,6 +300,7 @@ class S3BackedFileStore(FileStore):
file_type: str,
file_metadata: dict[str, Any] | None = None,
file_id: str | None = None,
db_session: Session | None = None,
) -> str:
if file_id is None:
file_id = str(uuid.uuid4())
@@ -314,27 +325,33 @@ class S3BackedFileStore(FileStore):
ContentType=file_type,
)
# Save metadata to database
upsert_filerecord(
file_id=file_id,
display_name=display_name or file_id,
file_origin=file_origin,
file_type=file_type,
bucket_name=bucket_name,
object_key=s3_key,
db_session=self.db_session,
file_metadata=file_metadata,
)
self.db_session.commit()
with get_session_with_current_tenant_if_none(db_session) as db_session:
# Save metadata to database
upsert_filerecord(
file_id=file_id,
display_name=display_name or file_id,
file_origin=file_origin,
file_type=file_type,
bucket_name=bucket_name,
object_key=s3_key,
db_session=db_session,
file_metadata=file_metadata,
)
db_session.commit()
return file_id
def read_file(
self, file_id: str, mode: str | None = None, use_tempfile: bool = False
self,
file_id: str,
mode: str | None = None,
use_tempfile: bool = False,
db_session: Session | None = None,
) -> IO[bytes]:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=self.db_session
)
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
s3_client = self._get_s3_client()
try:
@@ -356,32 +373,37 @@ class S3BackedFileStore(FileStore):
else:
return BytesIO(file_content)
def read_file_record(self, file_id: str) -> FileStoreModel:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=self.db_session
)
def read_file_record(
self, file_id: str, db_session: Session | None = None
) -> FileStoreModel:
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
return file_record
def delete_file(self, file_id: str) -> None:
try:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=self.db_session
)
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
with get_session_with_current_tenant_if_none(db_session) as db_session:
try:
# Delete from external storage
s3_client = self._get_s3_client()
s3_client.delete_object(
Bucket=file_record.bucket_name, Key=file_record.object_key
)
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
# Delete metadata from database
delete_filerecord_by_file_id(file_id=file_id, db_session=self.db_session)
# Delete from external storage
s3_client = self._get_s3_client()
s3_client.delete_object(
Bucket=file_record.bucket_name, Key=file_record.object_key
)
self.db_session.commit()
# Delete metadata from database
delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
except Exception:
self.db_session.rollback()
raise
db_session.commit()
except Exception:
db_session.rollback()
raise
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
mime_type: str = "application/octet-stream"
@@ -395,8 +417,18 @@ class S3BackedFileStore(FileStore):
except Exception:
return None
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
"""
List all file IDs that start with the given prefix.
"""
with get_session_with_current_tenant() as db_session:
file_records = get_filerecord_by_prefix(
prefix=prefix, db_session=db_session
)
return file_records
def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
def get_s3_file_store() -> S3BackedFileStore:
"""
Returns the S3 file store implementation.
"""
@@ -409,7 +441,6 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
)
return S3BackedFileStore(
db_session=db_session,
bucket_name=bucket_name,
aws_access_key_id=S3_AWS_ACCESS_KEY_ID,
aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY,
@@ -420,7 +451,7 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
)
def get_default_file_store(db_session: Session) -> FileStore:
def get_default_file_store() -> FileStore:
"""
Returns the configured file store implementation.
@@ -445,4 +476,4 @@ def get_default_file_store(db_session: Session) -> FileStore:
Other S3-compatible storage (Digital Ocean, Linode, etc.):
- Same as MinIO, but set appropriate S3_ENDPOINT_URL
"""
return get_s3_file_store(db_session)
return get_s3_file_store()

View File

@@ -8,7 +8,6 @@ import requests
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
@@ -49,28 +48,23 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Use a separate session to avoid committing the caller's transaction
try:
with get_session_with_current_tenant() as file_store_session:
file_store = get_default_file_store(file_store_session)
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
return False
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
) -> InMemoryChatFile:
file_io = get_default_file_store(db_session).read_file(
file_descriptor["id"], mode="b"
)
def load_chat_file(file_descriptor: FileDescriptor) -> InMemoryChatFile:
file_io = get_default_file_store().read_file(file_descriptor["id"], mode="b")
return InMemoryChatFile(
file_id=file_descriptor["id"],
content=file_io.read(),
@@ -82,7 +76,6 @@ def load_chat_file(
def load_all_chat_files(
chat_messages: list[ChatMessage],
file_descriptors: list[FileDescriptor],
db_session: Session,
) -> list[InMemoryChatFile]:
file_descriptors_for_history: list[FileDescriptor] = []
for chat_message in chat_messages:
@@ -93,7 +86,7 @@ def load_all_chat_files(
list[InMemoryChatFile],
run_functions_tuples_in_parallel(
[
(load_chat_file, (file, db_session))
(load_chat_file, (file,))
for file in file_descriptors + file_descriptors_for_history
]
),
@@ -117,7 +110,7 @@ def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
raise ValueError(f"User file with id {file_id} not found")
# Get the file record to determine the appropriate chat file type
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_record = file_store.read_file_record(user_file.file_id)
# Determine appropriate chat file type based on the original file's MIME type
@@ -263,34 +256,29 @@ def get_user_files_as_user(
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_current_tenant() as db_session:
response = requests.get(url)
response.raise_for_status()
response = requests.get(url)
response.raise_for_status()
file_io = BytesIO(response.content)
file_store = get_default_file_store(db_session)
file_id = file_store.save_file(
content=file_io,
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type="image/png;base64",
)
return file_id
file_io = BytesIO(response.content)
file_store = get_default_file_store()
file_id = file_store.save_file(
content=file_io,
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type="image/png;base64",
)
return file_id
def save_file_from_base64(base64_string: str) -> str:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_id = file_store.save_file(
content=BytesIO(base64.b64decode(base64_string)),
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
)
return file_id
file_store = get_default_file_store()
file_id = file_store.save_file(
content=BytesIO(base64.b64decode(base64_string)),
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
)
return file_id
def save_file(

View File

@@ -4,6 +4,13 @@ from typing import Any
import httpx
def make_default_kwargs() -> dict[str, Any]:
return {
"http2": True,
"limits": httpx.Limits(),
}
class HttpxPool:
"""Class to manage a global httpx Client instance"""
@@ -11,10 +18,6 @@ class HttpxPool:
_lock: threading.Lock = threading.Lock()
# Default parameters for creation
DEFAULT_KWARGS = {
"http2": True,
"limits": lambda: httpx.Limits(),
}
def __init__(self) -> None:
pass
@@ -22,7 +25,7 @@ class HttpxPool:
@classmethod
def _init_client(cls, **kwargs: Any) -> httpx.Client:
"""Private helper method to create and return an httpx.Client."""
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
merged_kwargs = {**(make_default_kwargs()), **kwargs}
return httpx.Client(**merged_kwargs)
@classmethod

View File

@@ -41,7 +41,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.search_settings import get_active_search_settings
@@ -496,36 +495,33 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# Try to get image summary
try:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_record = file_store.read_file_record(
file_record = file_store.read_file_record(
file_id=section.image_file_id
)
if not file_record:
logger.warning(
f"Image file {section.image_file_id} not found in FileStore"
)
processed_section.text = "[Image could not be processed]"
else:
# Get the image data
image_data_io = file_store.read_file(
file_id=section.image_file_id
)
if not file_record:
logger.warning(
f"Image file {section.image_file_id} not found in FileStore"
)
image_data = image_data_io.read()
summary = summarize_image_with_error_handling(
llm=llm,
image_data=image_data,
context_name=file_record.display_name or "Image",
)
processed_section.text = "[Image could not be processed]"
if summary:
processed_section.text = summary
else:
# Get the image data
image_data_io = file_store.read_file(
file_id=section.image_file_id
)
image_data = image_data_io.read()
summary = summarize_image_with_error_handling(
llm=llm,
image_data=image_data,
context_name=file_record.display_name or "Image",
)
if summary:
processed_section.text = summary
else:
processed_section.text = (
"[Image could not be summarized]"
)
processed_section.text = "[Image could not be summarized]"
except Exception as e:
logger.error(f"Error processing image section: {e}")
processed_section.text = "[Error processing image]"

View File

@@ -258,7 +258,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
# set up the file store (e.g. create bucket if needed). On multi-tenant,
# this is done via IaC
get_default_file_store(db_session).initialize()
get_default_file_store().initialize()
else:
setup_multitenant_onyx()

View File

@@ -16,24 +16,26 @@ class RedisConnector:
"""Composes several classes to simplify interacting with a connector and its
associated background tasks / associated redis interactions."""
def __init__(self, tenant_id: str, id: int) -> None:
def __init__(self, tenant_id: str, cc_pair_id: int) -> None:
"""id: a connector credential pair id"""
self.tenant_id: str = tenant_id
self.id: int = id
self.cc_pair_id: int = cc_pair_id
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
self.stop = RedisConnectorStop(tenant_id, cc_pair_id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, cc_pair_id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, cc_pair_id, self.redis)
self.permissions = RedisConnectorPermissionSync(
tenant_id, cc_pair_id, self.redis
)
self.external_group_sync = RedisConnectorExternalGroupSync(
tenant_id, id, self.redis
tenant_id, cc_pair_id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.id, search_settings_id, self.redis
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
)
def wait_for_indexing_termination(

View File

@@ -58,10 +58,7 @@ class RedisConnectorDelete:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorDeletePayload | None:
@@ -93,10 +90,7 @@ class RedisConnectorDelete:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"

View File

@@ -88,10 +88,7 @@ class RedisConnectorPermissionSync:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
@@ -128,10 +125,7 @@ class RedisConnectorPermissionSync:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
@property
def generator_complete(self) -> int | None:

View File

@@ -84,10 +84,7 @@ class RedisConnectorExternalGroupSync:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
@@ -125,10 +122,7 @@ class RedisConnectorExternalGroupSync:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
@property
def generator_complete(self) -> int | None:

View File

@@ -31,7 +31,10 @@ class RedisConnectorIndex:
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
DB_LOCK_PREFIX = "da_lock:indexing:db"
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
@@ -53,33 +56,45 @@ class RedisConnectorIndex:
def __init__(
self,
tenant_id: str,
id: int,
cc_pair_id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str = tenant_id
self.id = id
self.cc_pair_id = cc_pair_id
self.search_settings_id = search_settings_id
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
self.fence_key: str = f"{self.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.generator_progress_key = (
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
f"{self.GENERATOR_PROGRESS_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.filestore_lock_key = (
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
self.per_worker_lock_key = (
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.terminate_key = (
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.connector_active_key = (
f"{self.CONNECTOR_ACTIVE_PREFIX}_{id}/{search_settings_id}"
f"{self.CONNECTOR_ACTIVE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
def lock_key_by_batch(self, batch_n: int) -> str:
return f"{self.per_worker_lock_key}/{batch_n}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
@@ -89,7 +104,7 @@ class RedisConnectorIndex:
# we prefix the task id so it's easier to keep track of who created the task
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
return f"{self.GENERATOR_TASK_PREFIX}_{self.cc_pair_id}/{self.search_settings_id}_{uuid4()}"
@property
def fenced(self) -> bool:
@@ -160,10 +175,7 @@ class RedisConnectorIndex:
self.redis.set(self.connector_active_key, 0, ex=self.CONNECTOR_ACTIVE_TTL)
def connector_active(self) -> bool:
if self.redis.exists(self.connector_active_key):
return True
return False
return bool(self.redis.exists(self.connector_active_key))
def connector_active_ttl(self) -> int:
"""Refer to https://redis.io/docs/latest/commands/ttl/
@@ -175,9 +187,6 @@ class RedisConnectorIndex:
ttl = cast(int, self.redis.ttl(self.connector_active_key))
return ttl
def generator_locked(self) -> bool:
return bool(self.redis.exists(self.generator_lock_key))
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
@@ -212,6 +221,8 @@ class RedisConnectorIndex:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.connector_active_key)
self.redis.delete(self.active_key)
self.redis.delete(self.filestore_lock_key)
self.redis.delete(self.db_lock_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
@@ -226,7 +237,7 @@ class RedisConnectorIndex:
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):

View File

@@ -92,10 +92,7 @@ class RedisConnectorPrune:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorPrunePayload | None:
@@ -130,10 +127,7 @@ class RedisConnectorPrune:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
@property
def generator_complete(self) -> int | None:

View File

@@ -23,10 +23,7 @@ class RedisConnectorStop:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
def set_fence(self, value: bool) -> None:
if not value:
@@ -37,10 +34,7 @@ class RedisConnectorStop:
@property
def timed_out(self) -> bool:
if self.redis.exists(self.timeout_key):
return False
return True
return not bool(self.redis.exists(self.timeout_key))
def set_timeout(self) -> None:
"""After calling this, call timed_out to determine if the timeout has been

View File

@@ -418,7 +418,7 @@ def extract_zip_metadata(zf: zipfile.ZipFile) -> dict[str, Any]:
return zip_metadata
def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResponse:
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -431,7 +431,7 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
deduped_file_paths = []
zip_metadata = {}
try:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
seen_zip = False
for file in files:
if file.content_type and file.content_type.startswith("application/zip"):
@@ -488,9 +488,8 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
return upload_files(files, db_session)
return upload_files(files)
@router.get("/admin/connector")

View File

@@ -176,10 +176,9 @@ def undelete_persona(
@admin_router.post("/upload-image")
def upload_file(
file: UploadFile,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> dict[str, str]:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_type = ChatFileType.IMAGE
file_id = file_store.save_file(
content=file.file,

View File

@@ -205,6 +205,6 @@ def create_deletion_attempt_for_connector_id(
if cc_pair.connector.source == DocumentSource.FILE:
connector = cc_pair.connector
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_name in connector.connector_specific_config.get("file_locations", []):
file_store.delete_file(file_name)

View File

@@ -720,7 +720,7 @@ def upload_files_for_chat(
detail="File size must be less than 20MB",
)
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
@@ -823,10 +823,9 @@ def upload_files_for_chat(
@router.get("/file/{file_id:path}")
def fetch_chat_file(
file_id: str,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> Response:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id)
if not file_record:
raise HTTPException(status_code=404, detail="File not found")

View File

@@ -11,7 +11,6 @@ from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAU
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.file_store.file_store import get_default_file_store
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
@@ -40,9 +39,8 @@ class OnyxRuntime:
onyx_file: FileWithMimeType | None = None
if db_filename:
with get_session_with_shared_schema() as db_session:
file_store = get_default_file_store(db_session)
onyx_file = file_store.get_file_with_mime_type(db_filename)
file_store = get_default_file_store()
onyx_file = file_store.get_file_with_mime_type(db_filename)
if not onyx_file:
onyx_file = OnyxStaticFileManager.get_static(static_filename)

View File

@@ -17,7 +17,6 @@ from requests import JSONDecodeError
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
@@ -205,27 +204,26 @@ class CustomTool(BaseTool):
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
file_id = str(uuid.uuid4())
file_id = str(uuid.uuid4())
# Handle both binary and text content
if isinstance(file_content, str):
content = BytesIO(file_content.encode())
else:
content = BytesIO(file_content)
# Handle both binary and text content
if isinstance(file_content, str):
content = BytesIO(file_content.encode())
else:
content = BytesIO(file_content)
file_store.save_file(
file_id=file_id,
content=content,
display_name=file_id,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=content_type,
file_metadata={
"content_type": content_type,
},
)
file_store.save_file(
file_id=file_id,
content=content,
display_name=file_id,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=content_type,
file_metadata={
"content_type": content_type,
},
)
return [file_id]
@@ -328,22 +326,21 @@ class CustomTool(BaseTool):
# Load files from storage
files = []
with get_session_with_current_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_id in response.tool_result.file_ids:
try:
file_io = file_store.read_file(file_id, mode="b")
files.append(
InMemoryChatFile(
file_id=file_id,
filename=file_id,
content=file_io.read(),
file_type=file_type,
)
for file_id in response.tool_result.file_ids:
try:
file_io = file_store.read_file(file_id, mode="b")
files.append(
InMemoryChatFile(
file_id=file_id,
filename=file_id,
content=file_io.read(),
file_type=file_type,
)
except Exception:
logger.exception(f"Failed to read file {file_id}")
)
except Exception:
logger.exception(f"Failed to read file {file_id}")
# Update prompt with file content
prompt_builder.update_user_prompt(

View File

@@ -384,3 +384,24 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
)
next_ind += 1
del future_to_index[future]
def parallel_yield_from_funcs(
funcs: list[Callable[..., R]],
max_workers: int = 10,
) -> Iterator[R]:
"""
Runs the list of functions with thread-level parallelism, yielding
results as available. The asynchronous nature of this yielding means
that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
FURTHER ITEMS WERE PRODUCED by the input funcs. Only use this function
if you are consuming all elements from the functions OR it is acceptable
for some extra function code to run and not have the result(s) yielded.
"""
def func_wrapper(func: Callable[[], R]) -> Iterator[R]:
yield func()
yield from parallel_yield(
[func_wrapper(func) for func in funcs], max_workers=max_workers
)

View File

@@ -69,7 +69,7 @@ def run_jobs() -> None:
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"--queues=connector_indexing",
"--queues=docprocessing",
]
cmd_worker_user_files_indexing = [
@@ -111,6 +111,19 @@ def run_jobs() -> None:
"--queues=kg_processing",
]
cmd_worker_docfetching = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"--queues=connector_doc_fetching,user_files_indexing",
]
cmd_beat = [
"celery",
"-A",
@@ -157,6 +170,13 @@ def run_jobs() -> None:
text=True,
)
worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
@@ -184,6 +204,9 @@ def run_jobs() -> None:
worker_kg_processing_thread = threading.Thread(
target=monitor_process, args=("KG_PROCESSING", worker_kg_processing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_primary_thread.start()
@@ -193,6 +216,7 @@ def run_jobs() -> None:
worker_user_files_indexing_thread.start()
worker_monitoring_thread.start()
worker_kg_processing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()
worker_primary_thread.join()
@@ -202,6 +226,7 @@ def run_jobs() -> None:
worker_user_files_indexing_thread.join()
worker_monitoring_thread.join()
worker_kg_processing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()

View File

@@ -213,7 +213,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
if file_names:
logger.notice("Deleting stored files!")
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
for file_name in file_names:
logger.notice(f"Deleting file {file_name}")
file_store.delete_file(file_name)

View File

@@ -69,7 +69,7 @@ stopasgroup=true
command=celery -A onyx.background.celery.versioned_apps.indexing worker
--loglevel=INFO
--hostname=indexing@%%n
-Q connector_indexing
-Q docprocessing
stdout_logfile=/var/log/celery_worker_indexing.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
@@ -89,6 +89,18 @@ autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
--loglevel=INFO
--hostname=docfetching@%%n
-Q connector_doc_fetching
stdout_logfile=/var/log/celery_worker_docfetching.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
--loglevel=INFO
@@ -161,6 +173,7 @@ command=tail -qF
/var/log/celery_worker_heavy.log
/var/log/celery_worker_indexing.log
/var/log/celery_worker_user_files_indexing.log
/var/log/celery_worker_docfetching.log
/var/log/celery_worker_monitoring.log
/var/log/slack_bot.log
/var/log/supervisord_watchdog_celery_beat.log

View File

@@ -70,20 +70,19 @@ def _get_all_backend_configs() -> List[BackendConfig]:
if S3_ENDPOINT_URL:
minio_access_key = "minioadmin"
minio_secret_key = "minioadmin"
if minio_access_key and minio_secret_key:
configs.append(
{
"endpoint_url": S3_ENDPOINT_URL,
"access_key": minio_access_key,
"secret_key": minio_secret_key,
"region": "us-east-1",
"verify_ssl": False,
"backend_name": "MinIO",
}
)
configs.append(
{
"endpoint_url": S3_ENDPOINT_URL,
"access_key": minio_access_key,
"secret_key": minio_secret_key,
"region": "us-east-1",
"verify_ssl": False,
"backend_name": "MinIO",
}
)
# AWS S3 configuration (if credentials are available)
if S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
elif S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
configs.append(
{
"endpoint_url": None,
@@ -116,7 +115,6 @@ def file_store(
# Create S3BackedFileStore with backend-specific configuration
store = S3BackedFileStore(
db_session=db_session,
bucket_name=TEST_BUCKET_NAME,
aws_access_key_id=backend_config["access_key"],
aws_secret_access_key=backend_config["secret_key"],
@@ -827,7 +825,6 @@ class TestS3BackedFileStore:
# Create a new database session for each worker to avoid conflicts
with get_session_with_current_tenant() as worker_session:
worker_file_store = S3BackedFileStore(
db_session=worker_session,
bucket_name=current_bucket_name,
aws_access_key_id=current_access_key,
aws_secret_access_key=current_secret_key,
@@ -849,6 +846,7 @@ class TestS3BackedFileStore:
display_name=f"Worker {worker_id} File",
file_origin=file_origin,
file_type=file_type,
db_session=worker_session,
)
results.append((file_name, content))
return True
@@ -885,3 +883,94 @@ class TestS3BackedFileStore:
read_content_io = file_store.read_file(file_id)
actual_content: str = read_content_io.read().decode("utf-8")
assert actual_content == expected_content
def test_list_files_by_prefix(self, file_store: S3BackedFileStore) -> None:
    """Test listing files by prefix returns only correctly prefixed files.

    Covers three cases: a matching prefix (only files whose ids START with
    the prefix are returned), the empty prefix (all saved files are
    returned), and a non-existent prefix (empty result).
    """
    test_prefix = "documents-batch-"
    # Files that should be returned (start with the prefix)
    prefixed_files: List[str] = [
        f"{test_prefix}001.txt",
        f"{test_prefix}002.json",
        f"{test_prefix}abc.pdf",
        f"{test_prefix}xyz-final.docx",
    ]
    # Files that should NOT be returned (don't start with prefix, even if they contain it)
    non_prefixed_files: List[str] = [
        f"other-{test_prefix}001.txt",  # Contains prefix but doesn't start with it
        f"backup-{test_prefix}data.txt",  # Contains prefix but doesn't start with it
        f"{uuid.uuid4()}.txt",  # Random file without prefix
        "reports-001.pdf",  # Different prefix
        f"my-{test_prefix[:-1]}.txt",  # Similar but not exact prefix
    ]
    all_files = prefixed_files + non_prefixed_files
    saved_file_ids: List[str] = []
    # Save all test files
    for file_name in all_files:
        content = f"Content for {file_name}"
        content_io = BytesIO(content.encode("utf-8"))
        # file_id is passed explicitly so the stored id equals the name,
        # making prefix matching on ids deterministic below
        returned_file_id = file_store.save_file(
            content=content_io,
            display_name=f"Display: {file_name}",
            file_origin=FileOrigin.OTHER,
            file_type="text/plain",
            file_id=file_name,
        )
        saved_file_ids.append(returned_file_id)
        # Verify file was saved
        assert returned_file_id == file_name
    # Test the list_files_by_prefix functionality
    prefix_results = file_store.list_files_by_prefix(test_prefix)
    # Extract file IDs from results
    returned_file_ids = [record.file_id for record in prefix_results]
    # Verify correct number of files returned
    assert len(returned_file_ids) == len(prefixed_files), (
        f"Expected {len(prefixed_files)} files with prefix '{test_prefix}', "
        f"but got {len(returned_file_ids)}: {returned_file_ids}"
    )
    # Verify all prefixed files are returned
    for expected_file_id in prefixed_files:
        assert expected_file_id in returned_file_ids, (
            f"File '{expected_file_id}' should be in results but was not found. "
            f"Returned files: {returned_file_ids}"
        )
    # Verify no non-prefixed files are returned
    for unexpected_file_id in non_prefixed_files:
        assert unexpected_file_id not in returned_file_ids, (
            f"File '{unexpected_file_id}' should NOT be in results but was found. "
            f"Returned files: {returned_file_ids}"
        )
    # Verify the returned records have correct properties
    for record in prefix_results:
        assert record.file_id.startswith(test_prefix)
        assert record.display_name == f"Display: {record.file_id}"
        assert record.file_origin == FileOrigin.OTHER
        assert record.file_type == "text/plain"
        assert record.bucket_name == file_store._get_bucket_name()
    # Test with empty prefix (should return all files we created)
    all_results = file_store.list_files_by_prefix("")
    all_returned_ids = [record.file_id for record in all_results]
    # Should include all our test files
    for file_id in saved_file_ids:
        assert (
            file_id in all_returned_ids
        ), f"File '{file_id}' should be in results for empty prefix"
    # Test with non-existent prefix
    nonexistent_results = file_store.list_files_by_prefix("nonexistent-prefix-")
    assert (
        len(nonexistent_results) == 0
    ), "Should return empty list for non-existent prefix"

View File

@@ -77,18 +77,15 @@ def sample_file_io(sample_content: bytes) -> BytesIO:
class TestExternalStorageFileStore:
"""Test external storage file store functionality (S3-compatible)"""
def test_get_default_file_store_s3(self, db_session: Session) -> None:
def test_get_default_file_store_s3(self) -> None:
"""Test that external storage file store is returned"""
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
assert isinstance(file_store, S3BackedFileStore)
def test_s3_client_initialization_with_credentials(
self, db_session: Session
) -> None:
def test_s3_client_initialization_with_credentials(self) -> None:
"""Test S3 client initialization with explicit credentials"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="test-key",
aws_secret_access_key="test-secret",
@@ -110,7 +107,6 @@ class TestExternalStorageFileStore:
"""Test S3 client initialization with IAM role (no explicit credentials)"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id=None,
aws_secret_access_key=None,
@@ -129,16 +125,16 @@ class TestExternalStorageFileStore:
assert "aws_access_key_id" not in call_kwargs
assert "aws_secret_access_key" not in call_kwargs
def test_s3_bucket_name_configuration(self, db_session: Session) -> None:
def test_s3_bucket_name_configuration(self) -> None:
"""Test S3 bucket name configuration"""
with patch(
"onyx.file_store.file_store.S3_FILE_STORE_BUCKET_NAME", "my-test-bucket"
):
file_store = S3BackedFileStore(db_session, bucket_name="my-test-bucket")
file_store = S3BackedFileStore(bucket_name="my-test-bucket")
bucket_name: str = file_store._get_bucket_name()
assert bucket_name == "my-test-bucket"
def test_s3_key_generation_default_prefix(self, db_session: Session) -> None:
def test_s3_key_generation_default_prefix(self) -> None:
"""Test S3 key generation with default prefix"""
with (
patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"),
@@ -147,11 +143,11 @@ class TestExternalStorageFileStore:
return_value="test-tenant",
),
):
file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
file_store = S3BackedFileStore(bucket_name="test-bucket")
s3_key: str = file_store._get_s3_key("test-file.txt")
assert s3_key == "onyx-files/test-tenant/test-file.txt"
def test_s3_key_generation_custom_prefix(self, db_session: Session) -> None:
def test_s3_key_generation_custom_prefix(self) -> None:
"""Test S3 key generation with custom prefix"""
with (
patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "custom-prefix"),
@@ -161,17 +157,15 @@ class TestExternalStorageFileStore:
),
):
file_store = S3BackedFileStore(
db_session, bucket_name="test-bucket", s3_prefix="custom-prefix"
bucket_name="test-bucket", s3_prefix="custom-prefix"
)
s3_key: str = file_store._get_s3_key("test-file.txt")
assert s3_key == "custom-prefix/test-tenant/test-file.txt"
def test_s3_key_generation_with_different_tenant_ids(
self, db_session: Session
) -> None:
def test_s3_key_generation_with_different_tenant_ids(self) -> None:
"""Test S3 key generation with different tenant IDs"""
with patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"):
file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
file_store = S3BackedFileStore(bucket_name="test-bucket")
# Test with tenant ID "tenant-1"
with patch(
@@ -224,9 +218,7 @@ class TestExternalStorageFileStore:
with patch("onyx.db.file_record.upsert_filerecord") as mock_upsert:
mock_upsert.return_value = Mock()
file_store = S3BackedFileStore(
mock_db_session, bucket_name="test-bucket"
)
file_store = S3BackedFileStore(bucket_name="test-bucket")
# This should not raise an exception
file_store.save_file(
@@ -235,6 +227,7 @@ class TestExternalStorageFileStore:
display_name="Test File",
file_origin=FileOrigin.OTHER,
file_type="text/plain",
db_session=mock_db_session,
)
# Verify S3 client was called correctly
@@ -244,14 +237,13 @@ class TestExternalStorageFileStore:
assert call_args[1]["Key"] == "onyx-files/public/test-file.txt"
assert call_args[1]["ContentType"] == "text/plain"
def test_minio_client_initialization(self, db_session: Session) -> None:
def test_minio_client_initialization(self) -> None:
"""Test S3 client initialization with MinIO endpoint"""
with (
patch("boto3.client") as mock_boto3,
patch("urllib3.disable_warnings"),
):
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="minioadmin",
aws_secret_access_key="minioadmin",
@@ -277,11 +269,10 @@ class TestExternalStorageFileStore:
assert config.signature_version == "s3v4"
assert config.s3["addressing_style"] == "path"
def test_minio_ssl_verification_enabled(self, db_session: Session) -> None:
def test_minio_ssl_verification_enabled(self) -> None:
"""Test MinIO with SSL verification enabled"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="test-key",
aws_secret_access_key="test-secret",
@@ -295,11 +286,10 @@ class TestExternalStorageFileStore:
assert "verify" not in call_kwargs or call_kwargs.get("verify") is not False
assert call_kwargs["endpoint_url"] == "https://minio.example.com"
def test_aws_s3_without_endpoint_url(self, db_session: Session) -> None:
def test_aws_s3_without_endpoint_url(self) -> None:
"""Test that regular AWS S3 doesn't include endpoint URL or custom config"""
with patch("boto3.client") as mock_boto3:
file_store = S3BackedFileStore(
db_session,
bucket_name="test-bucket",
aws_access_key_id="test-key",
aws_secret_access_key="test-secret",
@@ -321,8 +311,8 @@ class TestExternalStorageFileStore:
class TestFileStoreInterface:
"""Test the general file store interface"""
def test_file_store_always_external_storage(self, db_session: Session) -> None:
def test_file_store_always_external_storage(self) -> None:
"""Test that external storage file store is always returned"""
# File store should always be S3BackedFileStore regardless of environment
file_store = get_default_file_store(db_session)
file_store = get_default_file_store()
assert isinstance(file_store, S3BackedFileStore)

View File

@@ -472,6 +472,7 @@ services:
environment:
MINIO_ROOT_USER: ${MINIO_ROOT_USER:-minioadmin}
MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-minioadmin}
# Note: we've seen the default bucket creation logic not work in some cases
MINIO_DEFAULT_BUCKETS: ${S3_FILE_STORE_BUCKET_NAME:-onyx-file-store-bucket}
volumes:
- minio_data:/data

View File

@@ -0,0 +1,81 @@
# Deployment for the Celery "docfetching" worker.
# Consumes the connector_doc_fetching and user_files_indexing queues; liveness
# and readiness are checked via celery_k8s_probe.py file-based probes.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "onyx-stack.fullname" . }}-celery-worker-docfetching
  labels:
    {{- include "onyx-stack.labels" . | nindent 4 }}
spec:
  {{- if not .Values.celery_worker_docfetching.autoscaling.enabled }}
  replicas: {{ .Values.celery_worker_docfetching.replicaCount }}
  {{- end }}
  selector:
    matchLabels:
      {{- include "onyx-stack.selectorLabels" . | nindent 6 }}
      {{- if .Values.celery_worker_docfetching.deploymentLabels }}
      {{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 6 }}
      {{- end }}
  template:
    metadata:
      {{- with .Values.celery_worker_docfetching.podAnnotations }}
      annotations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      labels:
        {{- include "onyx-stack.labels" . | nindent 8 }}
        {{- with .Values.celery_worker_docfetching.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "onyx-stack.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.celery_worker_docfetching.podSecurityContext | nindent 8 }}
      containers:
        - name: celery-worker-docfetching
          securityContext:
            {{- toYaml .Values.celery_worker_docfetching.securityContext | nindent 12 }}
          image: "{{ .Values.celery_shared.image.repository }}:{{ .Values.celery_shared.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.celery_shared.image.pullPolicy }}
          command:
            [
              "celery",
              "-A",
              "onyx.background.celery.versioned_apps.docfetching",
              "worker",
              "--loglevel=INFO",
              "--hostname=docfetching@%n",
              "-Q",
              "connector_doc_fetching,user_files_indexing",
            ]
          resources:
            {{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
          envFrom:
            - configMapRef:
                name: {{ .Values.config.envConfigMapName }}
          env:
            {{- include "onyx-stack.envSecrets" . | nindent 12}}
          startupProbe:
            {{ .Values.celery_shared.startupProbe | toYaml | nindent 12}}
          readinessProbe:
            {{ .Values.celery_shared.readinessProbe | toYaml | nindent 12}}
            exec:
              command:
                - /bin/bash
                - -c
                - >
                  python onyx/background/celery/celery_k8s_probe.py
                  --probe readiness
                  --filename /tmp/onyx_k8s_docfetching_readiness.txt
          livenessProbe:
            {{ .Values.celery_shared.livenessProbe | toYaml | nindent 12}}
            exec:
              command:
                - /bin/bash
                - -c
                - >
                  python onyx/background/celery/celery_k8s_probe.py
                  --probe liveness
                  --filename /tmp/onyx_k8s_docfetching_liveness.txt

View File

@@ -48,7 +48,7 @@ spec:
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
"docprocessing",
]
resources:
{{- toYaml .Values.celery_worker_indexing.resources | nindent 12 }}

View File

@@ -538,6 +538,27 @@ slackbot:
limits:
cpu: "1000m"
memory: "2000Mi"
celery_worker_docfetching:
replicaCount: 1
autoscaling:
enabled: false
podAnnotations: {}
podLabels:
scope: onyx-backend-celery
app: celery-worker-docfetching
deploymentLabels:
app: celery-worker-docfetching
podSecurityContext:
{}
securityContext:
privileged: true
runAsUser: 0
resources: {}
volumes: [] # Additional volumes on the output Deployment definition.
volumeMounts: [] # Additional volumeMounts on the output Deployment definition.
nodeSelector: {}
tolerations: []
affinity: {}
redis:
enabled: true

View File

@@ -217,7 +217,7 @@ const collections = (
<div className="ml-1">Document Processing</div>
</div>
),
link: "/admin/configuration/document-processing",
link: "/admin/configuration/docfetching",
},
...(kgExposed
? [