Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-01 13:45:44 +00:00)
Compare commits: refactor-m ... v1.0.0-clo (13 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 81ffcc00db | |
| | cc3ee12bff | |
| | 70a01680ff | |
| | 8da82dbfaf | |
| | 56e749dcee | |
| | f2ca7a8769 | |
| | c3de6a8e49 | |
| | c8cd85b284 | |
| | a899254766 | |
| | d33adfa91a | |
| | b9df82c5a1 | |
| | b8f652109f | |
| | 3aeaff7cda | |
@@ -9,7 +9,7 @@ Create Date: 2025-06-20 14:44:54.241159
from alembic import op
import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse

from httpx import HTTPStatusError
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
@@ -175,6 +175,7 @@ def update_document_id_in_database(old_doc_id: str, new_doc_id: str) -> None:
    )

    # Update search_doc table (stores search results for chat replay)
    # This is critical for agent functionality
    bind.execute(
        sa.text(
            "UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
@@ -396,6 +397,39 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
    try:
        bind = op.get_bind()

        # Delete from agent-related tables first (order matters due to foreign keys)
        # Delete from agent__sub_query__search_doc first since it references search_doc
        bind.execute(
            sa.text(
                """
                DELETE FROM agent__sub_query__search_doc
                WHERE search_doc_id IN (
                    SELECT id FROM search_doc WHERE document_id = :doc_id
                )
                """
            ),
            {"doc_id": current_doc_id},
        )

        # Delete from chat_message__search_doc
        bind.execute(
            sa.text(
                """
                DELETE FROM chat_message__search_doc
                WHERE search_doc_id IN (
                    SELECT id FROM search_doc WHERE document_id = :doc_id
                )
                """
            ),
            {"doc_id": current_doc_id},
        )

        # Now we can safely delete from search_doc
        bind.execute(
            sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
            {"doc_id": current_doc_id},
        )

        # Delete from document_by_connector_credential_pair
        bind.execute(
            sa.text(
@@ -405,11 +439,6 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
        )

        # Delete from other tables that reference this document
        bind.execute(
            sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
            {"doc_id": current_doc_id},
        )

        bind.execute(
            sa.text(
                "DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
@@ -531,7 +560,6 @@ def upgrade() -> None:
            updated_count += 1
        except Exception as e:
            print(f"Failed to update document {current_doc_id}: {e}")
            from httpx import HTTPStatusError

            if isinstance(e, HTTPStatusError):
                print(f"HTTPStatusError: {e}")
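The new deletion hunk above removes rows in foreign-key order: the two agent/chat mapping tables first, then search_doc itself. A quick post-migration sanity check along these lines can confirm nothing was left dangling; this is a sketch with a placeholder connection string, not part of the migration.

```python
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # placeholder DSN

with engine.connect() as conn:
    # mapping rows must never outlive the search_doc rows they point at
    orphaned = conn.execute(
        text(
            """
            SELECT COUNT(*) FROM agent__sub_query__search_doc a
            LEFT JOIN search_doc s ON s.id = a.search_doc_id
            WHERE s.id IS NULL
            """
        )
    ).scalar_one()
    print(f"orphaned agent__sub_query__search_doc rows: {orphaned}")
```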
@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:

    # only create external store if we have files to migrate. This line
    # makes it so we need to have S3/MinIO configured to run this migration.
    external_store = get_s3_file_store(db_session=session)
    external_store = get_s3_file_store()

    for i, file_id in enumerate(files_to_migrate, 1):
        print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
    # Get database session
    bind = op.get_bind()
    session = Session(bind=bind)
    external_store = get_s3_file_store(db_session=session)
    external_store = get_s3_file_store()

    # Find all files currently stored in PostgreSQL (lobj_oid is not null)
    result = session.execute(
@@ -91,7 +91,7 @@ def export_query_history_task(
    with get_session_with_current_tenant() as db_session:
        try:
            stream.seek(0)
            get_default_file_store(db_session).save_file(
            get_default_file_store().save_file(
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
@@ -422,7 +422,7 @@ def connector_permission_sync_generator_task(

    lock: RedisLock = r.lock(
        OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        + f"_{redis_connector.cc_pair_id}",
        timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
        thread_local=False,
    )

@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(

    lock: RedisLock = r.lock(
        OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        + f"_{redis_connector.cc_pair_id}",
        timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
    )

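Both hunks swap the lock suffix from redis_connector.id to redis_connector.cc_pair_id, so the lock now keys off the connector-credential pair. A minimal sketch of the resulting acquire/release pattern; the prefix value, timeout, and ids below are illustrative stand-ins, not the real OnyxRedisLocks constants.

```python
from redis import Redis
from redis.lock import Lock as RedisLock

CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = "connector_doc_permissions_sync"  # illustrative value

r = Redis()  # placeholder connection
cc_pair_id = 42  # hypothetical connector-credential pair id

lock: RedisLock = r.lock(
    CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX + f"_{cc_pair_id}",
    timeout=300,  # stands in for CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
    thread_local=False,  # the callback may fire on a different thread
)
if lock.acquire(blocking=False):
    try:
        ...  # permission sync work
    finally:
        lock.release()
```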
@@ -114,7 +114,6 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:


def get_usage_report_data(
    db_session: Session,
    report_display_name: str,
) -> IO:
    """
@@ -128,7 +127,7 @@ def get_usage_report_data(
    Returns:
        The usage report data.
    """
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    # usage report may be very large, so don't load it all into memory
    return file_store.read_file(
        file_id=report_display_name, mode="b", use_tempfile=True
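get_usage_report_data drops its db_session parameter and, like the other file-store hunks in this commit, calls get_default_file_store() with no session. A sketch of the new read path; the import path and file id are assumptions for illustration only.

```python
from onyx.file_store.file_store import get_default_file_store  # assumed module path

file_store = get_default_file_store()  # no db_session argument anymore
report_io = file_store.read_file(
    file_id="usage_report_2025.csv",  # hypothetical file id
    mode="b",
    use_tempfile=True,  # stream large reports through a temp file instead of memory
)
```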
@@ -134,15 +134,14 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
    file: UploadFile,
    is_logotype: bool = False,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
    upload_logo(file=file, is_logotype=is_logotype)


def fetch_logo_helper(db_session: Session) -> Response:
    try:
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
        if not onyx_file:
            raise ValueError("get_onyx_file returned None!")
@@ -158,7 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response:

def fetch_logotype_helper(db_session: Session) -> Response:
    try:
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
        if not onyx_file:
            raise ValueError("get_onyx_file returned None!")
@@ -6,7 +6,6 @@ from typing import IO

from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session

from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -99,9 +98,7 @@ def guess_file_type(filename: str) -> str:
    return "application/octet-stream"


def upload_logo(
    db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
    content: IO[Any]

    if isinstance(file, str):
@@ -129,7 +126,7 @@ def upload_logo(
    display_name = file.filename
    file_type = file.content_type or "image/jpeg"

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.save_file(
        content=content,
        display_name=display_name,
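With db_session gone from upload_logo, callers pass only the file (an UploadFile or a filesystem path string) and the logotype flag. A sketch of the two call shapes; the module path and the file path are assumptions.

```python
from ee.onyx.server.enterprise_settings.store import upload_logo  # assumed module path

# seeding from a path on disk
upload_logo(file="/app/assets/logo.png", is_logotype=False)  # hypothetical path

# from a FastAPI UploadFile inside a route handler
# upload_logo(file=uploaded_file, is_logotype=True)
```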
@@ -358,7 +358,7 @@ def get_query_history_export_status(
    # If task is None, then it's possible that the task has already finished processing.
    # Therefore, we should then check if the export file has already been stored inside of the file-store.
    # If that *also* doesn't exist, then we can return a 404.
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()

    report_name = construct_query_history_report_name(request_id)
    has_file = file_store.has_file(
@@ -385,7 +385,7 @@ def download_query_history_csv(
    ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])

    report_name = construct_query_history_report_name(request_id)
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    has_file = file_store.has_file(
        file_id=report_name,
        file_origin=FileOrigin.QUERY_HISTORY_CSV,
@@ -53,7 +53,7 @@ def read_usage_report(
    db_session: Session = Depends(get_session),
) -> Response:
    try:
        file = get_usage_report_data(db_session, report_name)
        file = get_usage_report_data(report_name)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))


@@ -112,7 +112,7 @@ def create_new_usage_report(
    period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
    report_id = str(uuid.uuid4())
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()

    messages_file_id = generate_chat_messages_report(
        db_session, file_store, report_id, period
@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
    store_ee_settings(final_enterprise_settings)


def _seed_logo(db_session: Session, logo_path: str | None) -> None:
def _seed_logo(logo_path: str | None) -> None:
    if logo_path:
        logger.notice("Uploading logo")
        upload_logo(db_session=db_session, file=logo_path)
        upload_logo(file=logo_path)


def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
    if seed_config.custom_tools is not None:
        _seed_custom_tools(db_session, seed_config.custom_tools)

    _seed_logo(db_session, seed_config.seeded_logo_path)
    _seed_logo(seed_config.seeded_logo_path)
    _seed_enterprise_settings(seed_config)
    _seed_analytics_script(seed_config)
backend/onyx/background/celery/apps/docfetching.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from typing import Any
from typing import cast

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    logger.info("worker_init signal received.")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:
        return

    app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    app_base.on_worker_shutdown(sender, **kwargs)


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.docfetching",
    ]
)
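The worker above only consumes tasks; producers enqueue them by name. Based on try_creating_docfetching_task later in this diff, dispatch looks roughly like the sketch below. The literal task and queue strings stand in for OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK and OnyxCeleryQueues.CONNECTOR_DOC_FETCHING and may not match the real constant values; the producer app import is also an assumption.

```python
from onyx.background.celery.versioned_apps.primary import app as primary_app  # assumed producer app

result = primary_app.send_task(
    "connector_doc_fetching_task",  # illustrative task name
    kwargs=dict(
        index_attempt_id=123,  # hypothetical ids
        cc_pair_id=7,
        search_settings_id=1,
        tenant_id="public",
    ),
    queue="connector_doc_fetching",  # illustrative queue name
)
print(result.id)
```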
backend/onyx/background/celery/configs/docfetching.py (new file, 23 lines)
@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# Document processing worker configuration
# Similar to indexing worker since it does similar CPU/memory intensive work
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1
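Since the app module loads this config with config_from_object, the effective settings can be checked directly on the Celery app. A small sketch, assuming the app module from the diff above is importable in your environment:

```python
from onyx.background.celery.apps.docfetching import celery_app

print(celery_app.conf.worker_pool)                 # "threads"
print(celery_app.conf.worker_concurrency)          # CELERY_WORKER_DOCFETCHING_CONCURRENCY
print(celery_app.conf.worker_prefetch_multiplier)  # 1
print(celery_app.conf.task_acks_late)
```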
backend/onyx/background/celery/tasks/docfetching/tasks.py (new file, 713 lines)
@@ -0,0 +1,713 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.indexing.tasks import ConnectorIndexingLogBuilder
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.models import ConnectorIndexingContext
|
||||
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
|
||||
from onyx.background.celery.tasks.models import SimpleJobResult
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _wait_for_fence_payload(
|
||||
redis_connector: RedisConnector,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
index_attempt_id: int,
|
||||
) -> RedisConnectorIndexPayload:
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
|
||||
)
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - fence not found: fence={redis_connector_index.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
)
|
||||
|
||||
payload = redis_connector_index.payload # The payload must exist
|
||||
if not payload:
|
||||
raise SimpleJobException(
|
||||
"docfetching_task: payload invalid or not found",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"docfetching_task - Waiting for fence: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
if payload.index_attempt_id != index_attempt_id:
|
||||
raise SimpleJobException(
|
||||
f"docfetching_task - id mismatch. Task may be left over from previous run.: "
|
||||
f"task_index_attempt={index_attempt_id} "
|
||||
f"payload_index_attempt={payload.index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"docfetching_task - Fence found, continuing...: fence={redis_connector_index.fence_key}"
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def docfetching_task(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This function is run in a SimpleJob as a new process. It is responsible for validating
|
||||
some stuff, but basically it just calls run_indexing_entrypoint.
|
||||
|
||||
NOTE: if an exception is raised out of this task, the primary worker will detect
|
||||
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
|
||||
This will cause the primary worker to abort the indexing attempt and clean up.
|
||||
"""
|
||||
|
||||
# Since connector_indexing_proxy_task spawns a new process using this function as
|
||||
# the entrypoint, we init Sentry here.
|
||||
if SENTRY_DSN:
|
||||
sentry_sdk.init(
|
||||
dsn=SENTRY_DSN,
|
||||
traces_sample_rate=0.1,
|
||||
)
|
||||
logger.info("Sentry initialized")
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
f"Indexing will not start because connector deletion is in progress: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.delete.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
|
||||
)
|
||||
|
||||
if redis_connector.stop.fenced:
|
||||
raise SimpleJobException(
|
||||
f"Indexing will not start because a connector stop signal was detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={redis_connector.stop.fence_key}",
|
||||
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
|
||||
)
|
||||
|
||||
payload = _wait_for_fence_payload(
|
||||
redis_connector, redis_connector_index, index_attempt_id
|
||||
)
|
||||
r = get_redis_client()
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise SimpleJobException(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
|
||||
)
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
f"Index attempt not found: index_attempt={index_attempt_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
raise SimpleJobException(
|
||||
f"cc_pair not found: cc_pair={cc_pair_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
# TODO: pretty sure it's impossible for these to happen, should probably delete
|
||||
if not cc_pair.connector:
|
||||
raise SimpleJobException(
|
||||
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
if not cc_pair.credential:
|
||||
raise SimpleJobException(
|
||||
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
|
||||
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
|
||||
)
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
redis_connector_index,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
app,
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
cc_pair_id,
|
||||
is_ee,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# get back the total number of indexed docs and return it
|
||||
redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
except ConnectorValidationError:
|
||||
raise SimpleJobException(
|
||||
f"Indexing task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# special bulletproofing ... truncate long exception messages
|
||||
# for exception types that require more args, this will fail
|
||||
# thus the try/except
|
||||
try:
|
||||
sanitized_e = type(e)(str(e)[:1024])
|
||||
sanitized_e.__traceback__ = e.__traceback__
|
||||
raise sanitized_e
|
||||
except Exception:
|
||||
raise e
|
||||
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
os._exit(0) # ensure process exits cleanly
|
||||
|
||||
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
result.connector_source = connector_source
|
||||
|
||||
if job.process:
|
||||
result.exit_code = job.process.exitcode
|
||||
|
||||
if job.status != "error":
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
return result
|
||||
|
||||
ignore_exitcode = False
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...",
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
else:
|
||||
if result.exit_code is not None:
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
result.exception_str = job.exception()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
)
|
||||
def docfetching_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
|
||||
docfetching and docprocessing.
|
||||
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
|
||||
|
||||
This task spawns a new process for a new scheduled index attempt. That
|
||||
new process (which runs the docfetching_task function) does the following:
|
||||
|
||||
TODO: clen up
|
||||
determines parameters of the indexing attempt (which connector indexing function to run,
|
||||
start and end time, from prev checkpoint or not), then run that connector. Specifically,
|
||||
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
|
||||
At the moment these two steps (reading external data and converting to an Onyx document)
|
||||
are not parallelized in most connectors; that's a subject for future work.
|
||||
upserts documents to postgres (index_doc_batch_prepare)
|
||||
chunks each document (optionally adds context for contextual rag)
|
||||
embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
|
||||
write chunks to vespa (write_chunks_to_vector_db_with_backoff)
|
||||
update document and indexing metadata in postgres
|
||||
|
||||
pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list
|
||||
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
|
||||
Some more Richard notes:
|
||||
celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
|
||||
To work around this, we use pool=threads and proxy our work to a spawned task.
|
||||
|
||||
acks_late must be set to False. Otherwise, celery's visibility timeout will
|
||||
cause any task that runs longer than the timeout to be redispatched by the broker.
|
||||
There appears to be no good workaround for this, so we need to handle redispatching
|
||||
manually.
|
||||
NOTE: we try/except all db access in this function because as a watchdog, this function
|
||||
needs to be extremely stable.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
result = SimpleJobResult()
|
||||
|
||||
ctx = ConnectorIndexingContext(
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=search_settings_id,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
log_builder = ConnectorIndexingLogBuilder(ctx)
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - starting",
|
||||
mp_start_method=str(multiprocessing.get_start_method()),
|
||||
)
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
|
||||
|
||||
job = client.submit(
|
||||
docfetching_task,
|
||||
self.app,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
global_version.is_ee_version(),
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
if not job or not job.process:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the process has moved out of the starting state
|
||||
num_waits = 0
|
||||
while True:
|
||||
if num_waits > 15:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
job.release()
|
||||
return
|
||||
|
||||
if job.process.is_alive() or job.process.exitcode is not None:
|
||||
break
|
||||
|
||||
sleep(1)
|
||||
num_waits += 1
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawn succeeded",
|
||||
pid=str(job.process.pid),
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
last_memory_emit_time = 0.0
|
||||
|
||||
# track the last ttl and the time it was observed
|
||||
last_activity_ttl_observed: float = time.monotonic()
|
||||
last_activity_ttl: int = 0
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError("Index attempt not found")
|
||||
|
||||
result.connector_source = (
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
|
||||
redis_connector_index.set_active() # renew active signal
|
||||
|
||||
# prime the connector active signal (renewed inside the connector)
|
||||
redis_connector_index.set_connector_active()
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
now = time.monotonic()
|
||||
|
||||
# renew watchdog signal (this has a shorter timeout than set_active)
|
||||
redis_connector_index.set_watchdog(True)
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
try:
|
||||
result = process_job_result(
|
||||
job, result.connector_source, redis_connector_index, log_builder
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawned task exceptioned"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
job.release()
|
||||
break
|
||||
|
||||
# log the memory usage for tracking down memory leaks / connector-specific memory issues
|
||||
pid = job.process.pid
|
||||
if pid is not None:
|
||||
# Only emit memory info once per minute (60 seconds)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_memory_emit_time >= 60.0:
|
||||
emit_process_memory(
|
||||
pid,
|
||||
"indexing_worker",
|
||||
{
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"search_settings_id": search_settings_id,
|
||||
"index_attempt_id": index_attempt_id,
|
||||
},
|
||||
)
|
||||
last_memory_emit_time = current_time
|
||||
|
||||
# if a termination signal is detected, break (exit point will clean up)
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
log_builder.build("Indexing watchdog - termination signal detected")
|
||||
)
|
||||
|
||||
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
|
||||
break
|
||||
|
||||
# if activity timeout is detected, break (exit point will clean up)
|
||||
ttl = redis_connector_index.connector_active_ttl()
|
||||
if ttl < 0:
|
||||
# verify expectations around ttl
|
||||
last_observed = last_activity_ttl_observed - now
|
||||
if now > last_activity_ttl_observed + last_activity_ttl:
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - activity timeout exceeded",
|
||||
last_observed=f"{last_observed:.2f}s",
|
||||
last_ttl=f"{last_activity_ttl}",
|
||||
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
)
|
||||
|
||||
result.status = (
|
||||
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
|
||||
)
|
||||
break
|
||||
else:
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - activity timeout expired unexpectedly, "
|
||||
"waiting for last observed TTL before exiting",
|
||||
last_observed=f"{last_observed:.2f}s",
|
||||
last_ttl=f"{last_activity_ttl}",
|
||||
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
)
|
||||
else:
|
||||
last_activity_ttl_observed = now
|
||||
last_activity_ttl = ttl
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception looking up index attempt"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# No need to expose full stack trace for validation errors
|
||||
result.exception_str = str(e)
|
||||
else:
|
||||
result.exception_str = traceback.format_exc()
|
||||
|
||||
# handle exit and reporting
|
||||
elapsed = time.monotonic() - start
|
||||
if result.exception_str is not None:
|
||||
# print with exception
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
ctx.index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=failure_reason,
|
||||
full_exception_trace=result.exception_str,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
)
|
||||
|
||||
normalized_exception_str = "None"
|
||||
if result.exception_str:
|
||||
normalized_exception_str = result.exception_str.replace(
|
||||
"\n", "\\n"
|
||||
).replace('"', '\\"')
|
||||
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=result.status.value,
|
||||
exit_code=str(result.exit_code),
|
||||
exception=f'"{normalized_exception_str}"',
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
|
||||
|
||||
# print without exception
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled"
|
||||
)
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Indexing watchdog - activity timeout exceeded: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
)
|
||||
job.cancel()
|
||||
else:
|
||||
pass
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
source=result.connector_source,
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
elapsed=f"{elapsed:.2f}s",
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector_index.set_watchdog(False)
|
||||
File diff suppressed because it is too large
@@ -123,10 +123,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
        self.last_parent_check = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_connector.stop.fenced:
            return True

        return False
        return bool(self.redis_connector.stop.fenced)

    def progress(self, tag: str, amount: int) -> None:
        """Amount isn't used yet."""
@@ -268,7 +265,7 @@ def validate_indexing_fence(
        return

    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    if found:
        # the celery task exists in the redis queue
@@ -327,7 +324,7 @@ def validate_indexing_fences(
    indexing tasks sent to celery are still in flight.
    """
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )

    # Use replica for this because the worst thing that happens
@@ -414,10 +411,12 @@ def should_index(
    )

    # uncomment for debugging
    # task_logger.info(f"_should_index: "
    #     f"cc_pair={cc_pair.id} "
    #     f"connector={cc_pair.connector_id} "
    #     f"refresh_freq={connector.refresh_freq}")
    task_logger.info(
        f"_should_index: "
        f"cc_pair={cc_pair.id} "
        f"connector={cc_pair.connector_id} "
        f"refresh_freq={connector.refresh_freq}"
    )

    # don't kick off indexing for `NOT_APPLICABLE` sources
    if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -517,7 +516,7 @@ def should_index(
    return True


def try_creating_indexing_task(
def try_creating_docfetching_task(
    celery_app: Celery,
    cc_pair: ConnectorCredentialPair,
    search_settings: SearchSettings,
@@ -592,16 +591,18 @@ def try_creating_indexing_task(
    custom_task_id = redis_connector_index.generate_generator_task_id()

    # Determine which queue to use based on whether this is a user file
    # TODO: at the moment the indexing pipeline is
    # shared between user files and connectors
    queue = (
        OnyxCeleryQueues.USER_FILES_INDEXING
        if cc_pair.is_user_file
        else OnyxCeleryQueues.CONNECTOR_INDEXING
        else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
    )

    # when the task is sent, we have yet to finish setting up the fence
    # therefore, the task must contain code that blocks until the fence is ready
    result = celery_app.send_task(
        OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
        OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
        kwargs=dict(
            index_attempt_id=index_attempt_id,
            cc_pair_id=cc_pair.id,
@@ -613,7 +614,7 @@ def try_creating_indexing_task(
        priority=OnyxCeleryPriority.MEDIUM,
    )
    if not result:
        raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
        raise RuntimeError("send_task for connector_doc_fetching_task failed.")

    # now fill out the fence with the rest of the data
    redis_connector_index.set_active()
backend/onyx/background/celery/tasks/models.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from enum import Enum

from pydantic import BaseModel


class ConnectorIndexingContext(BaseModel):
    tenant_id: str
    cc_pair_id: int
    search_settings_id: int
    index_attempt_id: int


class IndexingWatchdogTerminalStatus(str, Enum):
    """The different statuses the watchdog can finish with.

    TODO: create broader success/failure/abort categories
    """

    UNDEFINED = "undefined"

    SUCCEEDED = "succeeded"

    SPAWN_FAILED = "spawn_failed"  # connector spawn failed
    SPAWN_NOT_ALIVE = (
        "spawn_not_alive"  # spawn succeeded but process did not come alive
    )

    BLOCKED_BY_DELETION = "blocked_by_deletion"
    BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
    FENCE_NOT_FOUND = "fence_not_found"  # fence does not exist
    FENCE_READINESS_TIMEOUT = (
        "fence_readiness_timeout"  # fence exists but wasn't ready within the timeout
    )
    FENCE_MISMATCH = "fence_mismatch"  # task and fence metadata mismatch
    TASK_ALREADY_RUNNING = "task_already_running"  # task appears to be running already
    INDEX_ATTEMPT_MISMATCH = (
        "index_attempt_mismatch"  # expected index attempt metadata not found in db
    )

    CONNECTOR_VALIDATION_ERROR = (
        "connector_validation_error"  # the connector validation failed
    )
    CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

    # the watchdog received a termination signal
    TERMINATED_BY_SIGNAL = "terminated_by_signal"

    # the watchdog terminated the task due to no activity
    TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"

    # NOTE: this may actually be the same as SIGKILL, but parsed differently by python
    # consolidate once we know more
    OUT_OF_MEMORY = "out_of_memory"

    PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"

    @property
    def code(self) -> int:
        _ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
            IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
            IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
            IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
            IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
            IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
            IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
            IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
            IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
            IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
        }

        return _ENUM_TO_CODE[self]

    @classmethod
    def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
        _CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
            -9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
            137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
            247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
            248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
            249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
            250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
            251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
            252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
            253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
            254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
            255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
        }

        if code in _CODE_TO_ENUM:
            return _CODE_TO_ENUM[code]

        return IndexingWatchdogTerminalStatus.UNDEFINED


class SimpleJobResult:
    """The data we want to have when the watchdog finishes"""

    def __init__(self) -> None:
        self.status = IndexingWatchdogTerminalStatus.UNDEFINED
        self.connector_source = None
        self.exit_code = None
        self.exception_str = None

    status: IndexingWatchdogTerminalStatus
    connector_source: str | None
    exit_code: int | None
    exception_str: str | None
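The exit-code mapping above is symmetric by construction, and unknown codes fall back to UNDEFINED. A small round-trip check, runnable once this models module is importable:

```python
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus

status = IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING
assert status.code == 253
assert IndexingWatchdogTerminalStatus.from_code(253) is status

# unrecognized exit codes do not raise; they map to UNDEFINED
assert (
    IndexingWatchdogTerminalStatus.from_code(1)
    is IndexingWatchdogTerminalStatus.UNDEFINED
)
```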
@@ -882,7 +882,13 @@ def monitor_celery_queues_helper(

    r_celery = task.app.broker_connection().channel().client  # type: ignore
    n_celery = celery_get_queue_length("celery", r_celery)
    n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
    n_docfetching = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
    n_user_files_indexing = celery_get_queue_length(
        OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
    )
    n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
    n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
    n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -896,14 +902,20 @@ def monitor_celery_queues_helper(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )

    n_indexing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    n_docfetching_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
    n_docprocessing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.DOCPROCESSING, r_celery
    )

    task_logger.info(
        f"Queue lengths: celery={n_celery} "
        f"indexing={n_indexing} "
        f"indexing_prefetched={len(n_indexing_prefetched)} "
        f"docfetching={n_docfetching} "
        f"docfetching_prefetched={len(n_docfetching_prefetched)} "
        f"docprocessing={n_docprocessing} "
        f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
        f"user_files_indexing={n_user_files_indexing} "
        f"sync={n_sync} "
        f"deletion={n_deletion} "
        f"pruning={n_pruning} "

@@ -441,7 +441,7 @@ def connector_pruning_generator_task(
    # set thread_local=False since we don't control what thread the indexing/pruning
    # might run our callback with
    lock: RedisLock = r.lock(
        OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
        OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
        timeout=CELERY_PRUNING_LOCK_TIMEOUT,
        thread_local=False,
    )

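The monitor_celery_queues_helper hunk above now tracks the docfetching and docprocessing queues as well. For ad-hoc debugging, the same depths can be approximated straight from the Redis broker, since Celery's default queues are Redis lists. This is a rough sketch; it ignores priority sub-queues and uses illustrative queue names.

```python
from redis import Redis

r = Redis()  # placeholder broker connection
for queue in ("connector_doc_fetching", "docprocessing", "user_files_indexing"):
    print(queue, r.llen(queue))
```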
backend/onyx/background/celery/versioned_apps/docfetching.py (new file, 18 lines)
@@ -0,0 +1,18 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""

from celery import Celery

from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.docfetching import celery_app

    return celery_app


app = get_app()
@@ -33,7 +33,7 @@ def save_checkpoint(
    """Save a checkpoint for a given index attempt to the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.save_file(
        content=BytesIO(checkpoint.model_dump_json().encode()),
        display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(


def load_checkpoint(
    db_session: Session, index_attempt_id: int, connector: BaseConnector
    index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
    """Load a checkpoint for a given index attempt from the file store"""
    checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
    checkpoint_data = checkpoint_io.read().decode("utf-8")
    if isinstance(connector, CheckpointedConnector):
@@ -144,7 +144,6 @@ def get_latest_valid_checkpoint(

    try:
        previous_checkpoint = load_checkpoint(
            db_session=db_session,
            index_attempt_id=latest_valid_checkpoint_candidate.id,
            connector=connector,
        )
@@ -201,7 +200,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
    if not index_attempt.checkpoint_pointer:
        return None

    file_store = get_default_file_store(db_session)
    file_store = get_default_file_store()
    file_store.delete_file(index_attempt.checkpoint_pointer)

    index_attempt.checkpoint_pointer = None
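load_checkpoint now resolves the file store itself, so callers drop the session entirely. A sketch of the new call shape; the module path is an assumption.

```python
from onyx.background.indexing.checkpointing_utils import load_checkpoint  # assumed module path


def resume_checkpoint(attempt_id: int, connector):
    # only the attempt id and the connector are needed now; no db_session
    return load_checkpoint(index_attempt_id=attempt_id, connector=connector)
```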
@@ -5,7 +5,7 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from pydantic import BaseModel
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import source_should_fetch_permissions_during_indexing
|
||||
@@ -19,17 +19,22 @@ from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import DocIndexingContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
@@ -52,6 +57,7 @@ from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -59,6 +65,7 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
@@ -68,6 +75,7 @@ from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
|
||||
@@ -184,21 +192,8 @@ class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
class RunIndexingContext(BaseModel):
|
||||
index_name: str
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
should_fetch_permissions_during_indexing: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
|
||||
|
||||
def _check_connector_and_attempt_status(
|
||||
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
|
||||
db_session_temp: Session, ctx: DocExtractionContext, index_attempt_id: int
|
||||
) -> None:
|
||||
"""
|
||||
Checks the status of the connector credential pair and index attempt.
|
||||
@@ -227,6 +222,7 @@ def _check_connector_and_attempt_status(
|
||||
)
|
||||
|
||||
|
||||
# TODO: delete from here if ends up unused
|
||||
def _check_failure_threshold(
|
||||
total_failures: int,
|
||||
document_count: int,
|
||||
@@ -271,7 +267,12 @@ def _run_indexing(
|
||||
start_time = time.monotonic()  # just used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
index_attempt_start = get_index_attempt(
|
||||
db_session_temp,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
@@ -292,7 +293,7 @@ def _run_indexing(
|
||||
index_attempt_start.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = RunIndexingContext(
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
@@ -317,6 +318,7 @@ def _run_indexing(
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None,
|
||||
)
|
||||
|
||||
last_successful_index_poll_range_end = (
|
||||
@@ -416,7 +418,9 @@ def _run_indexing(
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
index_attempt = get_index_attempt(
|
||||
db_session_temp, index_attempt_id, eager_load_cc_pair=True
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
@@ -456,6 +460,7 @@ def _run_indexing(
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
# TODO: ensure this logic is moved correctly
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
unresolved_only=True,
|
||||
@@ -815,6 +820,7 @@ def _run_indexing(
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
connector_credential_pair_id: int,
|
||||
@@ -852,8 +858,14 @@ def run_indexing_entrypoint(
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
connector_document_extraction(
|
||||
app,
|
||||
index_attempt_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.search_settings_id,
|
||||
tenant_id,
|
||||
callback,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished{tenant_str}: "
|
||||
@@ -861,3 +873,557 @@ def run_indexing_entrypoint(
|
||||
f"config='{connector_config}' "
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
|
||||
def connector_document_extraction(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""Extract documents from connector and queue them for indexing pipeline processing.
|
||||
|
||||
This is the first part of the split indexing process that runs the connector
|
||||
and extracts documents, storing them in local files or S3 for later processing.
|
||||
"""
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
logger.info(
|
||||
f"Document extraction starting: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
|
||||
batch_storage = get_document_batch_storage(tenant_id, index_attempt_id)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
|
||||
index_attempt = None
|
||||
# comes from _run_indexing
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
|
||||
|
||||
if index_attempt.search_settings is None:
|
||||
raise ValueError("Search settings must be set for indexing")
|
||||
|
||||
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
|
||||
if index_attempt.connector_credential_pair.indexing_trigger is not None:
|
||||
logger.info(
|
||||
"Clearing indexing trigger: "
|
||||
f"cc_pair={index_attempt.connector_credential_pair.id} "
|
||||
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
|
||||
)
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
index_attempt.connector_credential_pair.id, None, db_session
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
from_beginning = index_attempt.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt.search_settings.index_name,
|
||||
cc_pair_id=cc_pair_id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
source=db_connector.source,
|
||||
earliest_index_time=(
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
),
|
||||
from_beginning=index_attempt.from_beginning,
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary=is_primary,
|
||||
search_settings_status=index_attempt.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None, # None means not completed yet
|
||||
should_fetch_permissions_during_indexing=(
|
||||
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
|
||||
and source_should_fetch_permissions_during_indexing(db_connector.source)
|
||||
and is_primary
|
||||
# if we've already successfully indexed, let the doc_sync job
|
||||
# take care of doc-level permissions
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
)
|
||||
|
||||
# Set up time windows for polling
|
||||
last_successful_index_poll_range_end = (
|
||||
ctx.earliest_index_time
|
||||
if ctx.from_beginning
|
||||
else get_last_successful_attempt_poll_range_end(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
earliest_index=ctx.earliest_index_time,
|
||||
search_settings=index_attempt.search_settings,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_poll_range_end, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
db_session=db_session,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
batch_storage.store_extraction_state(ctx)
|
||||
# initialize indexing state
|
||||
current_state = DocIndexingContext(
|
||||
batches_done=0,
|
||||
total_failures=0,
|
||||
net_doc_change=0,
|
||||
total_chunks=0,
|
||||
)
|
||||
batch_storage.store_indexing_state(current_state)
|
||||
|
||||
# TODO: maybe memory tracer here
|
||||
|
||||
# Set up connector runner
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=ctx.should_fetch_permissions_during_indexing,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling
|
||||
# OR
|
||||
# if the last attempt was successful
|
||||
if index_attempt.from_beginning or (
|
||||
most_recent_attempt and most_recent_attempt.status.is_successful()
|
||||
):
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint = get_latest_valid_checkpoint(
|
||||
db_session=db_session,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
# Save initial checkpoint
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
try:
|
||||
batch_num = 0
|
||||
total_doc_batches_queued = 0
|
||||
total_failures = 0
|
||||
document_count = 0
|
||||
|
||||
# Main extraction loop
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# NOTE: this progress callback runs on every loop. We've seen cases
|
||||
# where we loop many times with no new documents and eventually time
|
||||
# out, so only doing the callback after indexing isn't sufficient.
|
||||
# TODO: change to doc extraction if it doesn't break things
|
||||
callback.progress("_run_indexing", 0)
|
||||
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
_check_connector_and_attempt_status(
|
||||
db_session, ctx, index_attempt_id
|
||||
)
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session,
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures, document_count, batch_num, failure
|
||||
)
|
||||
|
||||
# Save checkpoint if provided
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# everything below is document processing work, so if there is no batch we can just continue
|
||||
if document_batch is None:
|
||||
continue
|
||||
|
||||
# Clean documents and create batch
|
||||
doc_batch_cleaned = strip_null_characters(document_batch)
|
||||
batch_description = []
|
||||
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
if (
|
||||
isinstance(section, TextSection)
|
||||
and section.text is not None
|
||||
):
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
# TODO: replicate index_attempt_md from _run_indexing, store in blob storage
|
||||
# instead of that big extraction state. index_attempt_md will be the
|
||||
# persisted communication between the connector extraction task and the
|
||||
# document processing task
|
||||
batch_id = f"batch_{batch_num}"
|
||||
|
||||
# Store documents in storage
|
||||
batch_storage.store_batch(batch_id, doc_batch_cleaned)
|
||||
|
||||
batch_storage.store_extraction_state(ctx)
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"batch_id": batch_id,
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_id={batch_id} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# Check checkpoint size periodically
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
|
||||
check_checkpoint_size(checkpoint)
|
||||
|
||||
# Save latest checkpoint
|
||||
# NOTE: checkpointing is used to track which batches have
|
||||
# been stored, NOT which batches have been fully indexed
|
||||
# as it used to be.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
save_checkpoint(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Document extraction completed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"batches_queued={total_doc_batches_queued} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
final_state = batch_storage.get_extraction_state()
|
||||
if final_state is None:
|
||||
raise RuntimeError("Extraction state should not be None")
|
||||
|
||||
# counts how many batches were queued
|
||||
final_state.doc_extraction_complete_batch_num = batch_num
|
||||
batch_storage.store_extraction_state(final_state)
|
||||
|
||||
check_indexing_completion(index_attempt_id, tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Document extraction failed: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
try:
|
||||
batch_storage.cleanup_all_batches()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to clean up document batches after extraction failure"
|
||||
)
|
||||
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# On validation errors during indexing, we want to cancel the indexing attempt
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to validation error."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
if not index_attempt:
|
||||
# should always be set by now
|
||||
raise RuntimeError("Should never happen.")
|
||||
|
||||
VALIDATION_ERROR_THRESHOLD = 5
|
||||
|
||||
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
limit=VALIDATION_ERROR_THRESHOLD,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
num_validation_errors = len(
|
||||
[
|
||||
index_attempt
|
||||
for index_attempt in recent_index_attempts
|
||||
if index_attempt.error_msg
|
||||
and index_attempt.error_msg.startswith(
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
|
||||
f" errors. Marking the CC Pair as invalid."
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
memory_tracer.stop()
|
||||
else:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
|
||||
def check_indexing_completion(
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
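"""Poll the shared batch-storage state until every queued batch has been
processed, then finalize the index attempt and clean up storage and the
Redis indexing fence."""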
|
||||
|
||||
logger.info(
|
||||
f"Checking for document processing completion: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
storage = get_document_batch_storage(tenant_id, index_attempt_id)
|
||||
|
||||
try:
|
||||
last_progress_time = time.monotonic()
|
||||
last_batches_completed = 0
|
||||
# TODO: if there are no indexing workers running, fail the attempt
|
||||
while True:
|
||||
# Get current state
|
||||
indexing_state = storage.ensure_indexing_state()
|
||||
extraction_state = storage.ensure_extraction_state()
|
||||
|
||||
# Check if extraction is complete and all batches are processed
|
||||
batches_total = extraction_state.doc_extraction_complete_batch_num
|
||||
batches_processed = indexing_state.batches_done
|
||||
extraction_completed = batches_total is not None and (
|
||||
batches_processed >= batches_total
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Processing status: "
|
||||
f"extraction_completed={extraction_completed} "
|
||||
f"batches_processed={batches_processed}/{batches_total or '?'}" # probably off by 1
|
||||
)
|
||||
|
||||
if extraction_completed:
|
||||
break
|
||||
|
||||
if batches_processed > last_batches_completed:
|
||||
last_batches_completed = batches_processed
|
||||
last_progress_time = time.monotonic()
|
||||
elif time.monotonic() - last_progress_time > 3600 * 6:
|
||||
raise RuntimeError(
|
||||
f"Indexing attempt {index_attempt_id} has been indexing for 6 hours without progress. "
|
||||
f"Marking it as failed."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"All batches for index attempt {index_attempt_id} have been processed."
|
||||
)
|
||||
|
||||
# All processing is complete
|
||||
total_failures = indexing_state.total_failures
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if total_failures == 0:
|
||||
mark_attempt_succeeded(index_attempt_id, db_session)
|
||||
logger.info(f"Index attempt {index_attempt_id} completed successfully")
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session)
|
||||
logger.info(
|
||||
f"Index attempt {index_attempt_id} completed with {total_failures} failures"
|
||||
)
|
||||
|
||||
# Clean up all remaining storage
|
||||
storage.cleanup_all_batches()
|
||||
|
||||
# Clean up the fence to allow the next run to start
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if index_attempt and index_attempt.search_settings:
|
||||
redis_connector = RedisConnector(
|
||||
tenant_id, index_attempt.connector_credential_pair_id
|
||||
)
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt.search_settings.id
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
logger.info(f"Resetting indexing fence for attempt {index_attempt_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to monitor document processing completion: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
|
||||
# Mark the attempt as failed if monitoring fails
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=f"Processing monitoring failed: {str(e)}",
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to mark attempt as failed")
|
||||
|
||||
# Try to clean up storage
|
||||
try:
|
||||
storage.cleanup_all_batches()
|
||||
except Exception:
|
||||
logger.exception("Failed to cleanup storage after monitoring failure")
|
||||
|
||||
@@ -725,9 +725,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
)
|
||||
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
|
||||
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
||||
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
||||
user_file_ids = new_msg_req.user_file_ids or []
|
||||
|
||||
@@ -324,6 +324,19 @@ except ValueError:
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
)
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
|
||||
try:
|
||||
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
|
||||
if not env_value:
|
||||
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
|
||||
|
||||
if not env_value:
|
||||
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
|
||||
except ValueError:
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 1024
|
||||
|
||||
@@ -66,6 +66,7 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
|
||||
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
@@ -320,9 +321,12 @@ class OnyxCeleryQueues:
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# Document processing pipeline queue
|
||||
DOCPROCESSING = "docprocessing"
|
||||
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
|
||||
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
@@ -453,7 +457,11 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
|
||||
# New split indexing tasks
|
||||
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
|
||||
DOCPROCESSING_TASK = "docprocessing_task"
|
||||
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
|
||||
@@ -34,7 +34,6 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
@@ -281,30 +280,28 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
# TODO: Refactor to avoid direct DB access in connector
|
||||
# This will require broader refactoring across the codebase
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
image_section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=downloaded_file,
|
||||
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
|
||||
display_name=file_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
image_section, _ = store_image_and_create_section(
|
||||
image_data=downloaded_file,
|
||||
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
|
||||
display_name=file_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=[image_section],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
sections=[image_section],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
doc_updated_at=last_modified,
|
||||
metadata={},
|
||||
)
|
||||
)
|
||||
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
except Exception:
|
||||
logger.exception(f"Error processing image {key}")
|
||||
continue
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.configs.app_configs import (
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
@@ -224,19 +223,17 @@ def _process_image_attachment(
|
||||
"""Process an image attachment by saving it without generating a summary."""
|
||||
try:
|
||||
# Use the standardized image storage and section creation
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, file_name = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=raw_bytes,
|
||||
file_id=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
section, file_name = store_image_and_create_section(
|
||||
image_data=raw_bytes,
|
||||
file_id=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
except Exception as e:
|
||||
msg = f"Image storage failed for {attachment['title']}: {e}"
|
||||
logger.error(msg, exc_info=e)
|
||||
|
||||
@@ -5,8 +5,6 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import IO
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
@@ -18,7 +16,6 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
@@ -32,7 +29,6 @@ logger = setup_logger()
|
||||
|
||||
def _create_image_section(
|
||||
image_data: bytes,
|
||||
db_session: Session,
|
||||
parent_file_name: str,
|
||||
display_name: str,
|
||||
link: str | None = None,
|
||||
@@ -58,7 +54,6 @@ def _create_image_section(
|
||||
# Store the image and create a section
|
||||
try:
|
||||
section, stored_file_name = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=image_data,
|
||||
file_id=file_id,
|
||||
display_name=display_name,
|
||||
@@ -77,7 +72,6 @@ def _process_file(
|
||||
file: IO[Any],
|
||||
metadata: dict[str, Any] | None,
|
||||
pdf_pass: str | None,
|
||||
db_session: Session,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Process a file and return a list of Documents.
|
||||
@@ -125,7 +119,6 @@ def _process_file(
|
||||
try:
|
||||
section, _ = _create_image_section(
|
||||
image_data=image_data,
|
||||
db_session=db_session,
|
||||
parent_file_name=file_id,
|
||||
display_name=title,
|
||||
)
|
||||
@@ -196,7 +189,6 @@ def _process_file(
|
||||
try:
|
||||
image_section, stored_file_name = _create_image_section(
|
||||
image_data=img_data,
|
||||
db_session=db_session,
|
||||
parent_file_name=file_id,
|
||||
display_name=f"{title} - image {idx}",
|
||||
idx=idx,
|
||||
@@ -260,37 +252,33 @@ class LocalFileConnector(LoadConnector):
|
||||
"""
|
||||
documents: list[Document] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for file_id in self.file_locations:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(file_id=file_id)
|
||||
if not file_record:
|
||||
# typically an unsupported extension
|
||||
logger.warning(
|
||||
f"No file record found for '{file_id}' in PG; skipping."
|
||||
)
|
||||
continue
|
||||
for file_id in self.file_locations:
|
||||
file_store = get_default_file_store()
|
||||
file_record = file_store.read_file_record(file_id=file_id)
|
||||
if not file_record:
|
||||
# typically an unsupported extension
|
||||
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
|
||||
continue
|
||||
|
||||
metadata = self._get_file_metadata(file_id)
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
file_name=file_record.display_name,
|
||||
file=file_io,
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
db_session=db_session,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
metadata = self._get_file_metadata(file_id)
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
file_name=file_record.display_name,
|
||||
file=file_io,
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
|
||||
documents = []
|
||||
|
||||
if documents:
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
|
||||
documents = []
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -1374,3 +1378,139 @@ class GoogleDriveConnector(
|
||||
@override
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
|
||||
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
|
||||
def get_credentials_from_env(email: str, oauth: bool) -> dict:
|
||||
if oauth:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
|
||||
else:
|
||||
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
|
||||
|
||||
refried_credential_string = json.dumps(json.loads(raw_credential_string))
|
||||
|
||||
# This is the OAuth token
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
|
||||
# This is the service account key
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
|
||||
# The email saved for both auth types
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
|
||||
cred_key = (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
if oauth
|
||||
else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
)
|
||||
return {
|
||||
cred_key: refried_credential_string,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
|
||||
}
|
||||
|
||||
|
||||
class CheckpointOutputWrapper:
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.next_checkpoint: GoogleDriveCheckpoint | None = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
# grabs the final return value and stores it in the `next_checkpoint` variable
|
||||
def _inner_wrapper(
|
||||
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
|
||||
) -> CheckpointOutput[GoogleDriveCheckpoint]:
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
|
||||
|
||||
def yield_all_docs_from_checkpoint_connector(
|
||||
connector: GoogleDriveConnector,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> Iterator[Document | ConnectorFailure]:
|
||||
num_iterations = 0
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
while checkpoint.has_more:
|
||||
doc_batch_generator = CheckpointOutputWrapper()(
|
||||
connector.load_from_checkpoint(start, end, checkpoint)
|
||||
)
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
yield failure
|
||||
if document is not None:
|
||||
yield document
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
num_iterations += 1
|
||||
if num_iterations > 100_000:
|
||||
raise RuntimeError("Too many iterations. Infinite loop?")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
creds = get_credentials_from_env(
|
||||
os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
|
||||
)
|
||||
connector = GoogleDriveConnector(
|
||||
include_shared_drives=True,
|
||||
shared_drive_urls=None,
|
||||
include_my_drives=True,
|
||||
my_drive_emails=None,
|
||||
shared_folder_urls=None,
|
||||
include_files_shared_with_me=True,
|
||||
specific_user_emails=None,
|
||||
)
|
||||
connector.load_credentials(creds)
|
||||
max_fsize = 0
|
||||
biggest_fsize = 0
|
||||
num_errors = 0
|
||||
start_time = time.time()
|
||||
with open("stats.txt", "w") as f:
|
||||
for num, doc_or_failure in enumerate(
|
||||
yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
|
||||
):
|
||||
if num % 200 == 0:
|
||||
f.write(f"Processed {num} files\n")
|
||||
f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
|
||||
f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
|
||||
f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
|
||||
biggest_fsize = max(biggest_fsize, max_fsize)
|
||||
max_fsize = 0
|
||||
if isinstance(doc_or_failure, Document):
|
||||
max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
|
||||
elif isinstance(doc_or_failure, ConnectorFailure):
|
||||
num_errors += 1
|
||||
print(f"Num errors: {num_errors}")
|
||||
print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
|
||||
print(f"Time taken: {time.time() - start_time:.2f} seconds")
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
|
||||
from onyx.file_processing.extract_file_text import docx_to_text_and_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -193,17 +192,15 @@ def _download_and_extract_sections_basic(
|
||||
# For images, store them for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response_call(),
|
||||
file_id=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=response_call(),
|
||||
file_id=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
@@ -216,16 +213,14 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
# Process embedded images in the PDF
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=img_data,
|
||||
file_id=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
image_data=img_data,
|
||||
file_id=f"{file_id}_img_{idx}",
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -68,10 +67,7 @@ class GoogleSitesConnector(LoadConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_content_io = get_default_file_store(db_session).read_file(
|
||||
self.zip_path, mode="b"
|
||||
)
|
||||
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
|
||||
|
||||
# load the HTML files
|
||||
files = load_files_from_zip(file_content_io)
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.utils.text_processing import make_url_compatible
|
||||
|
||||
|
||||
@@ -373,3 +374,24 @@ class OnyxMetadata(BaseModel):
|
||||
secondary_owners: list[BasicExpertInfo] | None = None
|
||||
doc_updated_at: datetime | None = None
|
||||
title: str | None = None
|
||||
|
||||
|
||||
class DocExtractionContext(BaseModel):
|
||||
index_name: str
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: DocumentSource
|
||||
earliest_index_time: float
|
||||
from_beginning: bool
|
||||
is_primary: bool
|
||||
should_fetch_permissions_during_indexing: bool
|
||||
search_settings_status: IndexModelStatus
|
||||
doc_extraction_complete_batch_num: int | None
|
||||
|
||||
|
||||
class DocIndexingContext(BaseModel):
|
||||
batches_done: int
|
||||
total_failures: int
|
||||
net_doc_change: int
|
||||
total_chunks: int
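# Illustrative example, not part of this change: these two models are the
# persisted hand-off between the extraction task and the processing workers.
# The ids below are example values; get_document_batch_storage comes from
# onyx.file_store.document_batch_storage.
from onyx.file_store.document_batch_storage import get_document_batch_storage

storage = get_document_batch_storage("tenant1", 42)  # example tenant id / attempt id
extraction_ctx = storage.ensure_extraction_state()  # written by the extraction task
indexing_ctx = storage.ensure_indexing_state()  # updated by processing workers
all_batches_processed = (
    extraction_ctx.doc_extraction_complete_batch_num is not None
    and indexing_ctx.batches_done >= extraction_ctx.doc_extraction_complete_batch_num
)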
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RerankMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
@@ -70,18 +69,16 @@ def update_image_sections_with_query(
|
||||
logger.debug(
|
||||
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_record = get_default_file_store(db_session).read_file(
|
||||
cast(str, chunk.image_file_id), mode="b"
|
||||
)
|
||||
if not file_record:
|
||||
logger.error(f"Image file not found: {chunk.image_file_id}")
|
||||
raise Exception("File not found")
|
||||
file_content = file_record.read()
|
||||
image_base64 = base64.b64encode(file_content).decode()
|
||||
logger.debug(
|
||||
f"Successfully loaded image data for {chunk.image_file_id}"
|
||||
)
|
||||
|
||||
file_record = get_default_file_store().read_file(
|
||||
cast(str, chunk.image_file_id), mode="b"
|
||||
)
|
||||
if not file_record:
|
||||
logger.error(f"Image file not found: {chunk.image_file_id}")
|
||||
raise Exception("File not found")
|
||||
file_content = file_record.read()
|
||||
image_base64 = base64.b64encode(file_content).decode()
|
||||
logger.debug(f"Successfully loaded image data for {chunk.image_file_id}")
|
||||
|
||||
messages: list[BaseMessage] = [
|
||||
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),
|
||||
|
||||
@@ -229,7 +229,7 @@ def delete_messages_and_files_from_chat_session(
|
||||
delete_tool_call_for_message_id(message_id=id, db_session=db_session)
|
||||
delete_search_doc_message_relationship(message_id=id, db_session=db_session)
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
for file_info in files or []:
|
||||
file_store.delete_file(file_id=file_info.get("id"))
|
||||
|
||||
|
||||
@@ -258,6 +258,7 @@ def get_connector_credential_pair_from_id_for_user(
|
||||
def get_connector_credential_pair_from_id(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
@@ -265,6 +266,8 @@ def get_connector_credential_pair_from_id(
|
||||
|
||||
if eager_load_credential:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -304,6 +304,18 @@ def get_session_with_current_tenant() -> Generator[Session, None, None]:
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_current_tenant_if_none(
|
||||
session: Session | None,
|
||||
) -> Generator[Session, None, None]:
|
||||
if session is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as session:
|
||||
yield session
|
||||
else:
|
||||
yield session
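# Illustrative usage, not part of this change: a helper that can run either
# inside a caller-provided Session or standalone. `load_widget` and `Widget`
# are hypothetical names used only for this example.
def load_widget(widget_id: int, db_session: Session | None = None) -> "Widget | None":
    with get_session_with_current_tenant_if_none(db_session) as session:
        return session.get(Widget, widget_id)  # Widget is a hypothetical model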
|
||||
|
||||
|
||||
# Used in multi tenant mode when need to refer to the shared `public` schema
|
||||
@contextmanager
|
||||
def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
|
||||
@@ -43,6 +43,15 @@ def get_filerecord_by_file_id(
|
||||
return filestore
|
||||
|
||||
|
||||
def get_filerecord_by_prefix(
|
||||
prefix: str,
|
||||
db_session: Session,
|
||||
) -> list[FileRecord]:
|
||||
return (
|
||||
db_session.query(FileRecord).filter(FileRecord.file_id.like(f"{prefix}%")).all()
|
||||
)
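# Illustrative usage, not part of this change: given an open db_session, list
# every stored file record whose id starts with a prefix, e.g. all document
# batches written for one index attempt. The prefix value is an example.
example_records = get_filerecord_by_prefix("document_batch_tenant1_42_", db_session)
for record in example_records:
    print(record.file_id, record.display_name)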
|
||||
|
||||
|
||||
def delete_filerecord_by_file_id(
|
||||
file_id: str,
|
||||
db_session: Session,
|
||||
|
||||
@@ -28,6 +28,8 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
# from onyx.auth.models import UserRole
|
||||
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
|
||||
@@ -95,9 +97,25 @@ def get_recent_attempts_for_cc_pair(
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
eager_load_cc_pair: bool = False,
|
||||
eager_load_search_settings: bool = False,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
|
||||
if eager_load_cc_pair:
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.connector
|
||||
)
|
||||
)
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.credential
|
||||
)
|
||||
)
|
||||
if eager_load_search_settings:
|
||||
stmt = stmt.options(joinedload(IndexAttempt.search_settings))
|
||||
return db_session.scalars(stmt).first()
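# Illustrative usage, not part of this change: with eager loading enabled, the
# attempt's connector_credential_pair (plus its connector and credential) and
# search_settings remain usable after the originating session closes, which is
# how the docfetching task above relies on it. The id is an example value and
# db_session is assumed to be an open Session.
attempt = get_index_attempt(
    db_session,
    index_attempt_id=42,
    eager_load_cc_pair=True,
    eager_load_search_settings=True,
)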
|
||||
|
||||
|
||||
@@ -373,16 +391,22 @@ def update_docs_indexed(
|
||||
new_docs_indexed: int,
|
||||
docs_removed_from_index: int,
|
||||
) -> None:
|
||||
"""Updates the docs_indexed and new_docs_indexed fields of an index attempt.
|
||||
Adds the given values to the current values in the db"""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(IndexAttempt)
|
||||
.where(IndexAttempt.id == index_attempt_id)
|
||||
.with_for_update()
|
||||
.with_for_update() # Locks the row when we try to update
|
||||
).scalar_one()
|
||||
|
||||
attempt.total_docs_indexed = total_docs_indexed
|
||||
attempt.new_docs_indexed = new_docs_indexed
|
||||
attempt.docs_removed_from_index = docs_removed_from_index
|
||||
attempt.total_docs_indexed = (
|
||||
attempt.total_docs_indexed or 0
|
||||
) + total_docs_indexed
|
||||
attempt.new_docs_indexed = (attempt.new_docs_indexed or 0) + new_docs_indexed
|
||||
attempt.docs_removed_from_index = (
|
||||
attempt.docs_removed_from_index or 0
|
||||
) + docs_removed_from_index
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -46,7 +46,7 @@ def create_user_files(
|
||||
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(files, db_session)
|
||||
upload_response = upload_files(files)
|
||||
user_files = []
|
||||
|
||||
for file_path, file in zip(upload_response.file_paths, files):
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -12,7 +10,6 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def store_image_and_create_section(
|
||||
db_session: Session,
|
||||
image_data: bytes,
|
||||
file_id: str,
|
||||
display_name: str,
|
||||
@@ -24,7 +21,6 @@ def store_image_and_create_section(
|
||||
Stores an image in FileStore and creates an ImageSection object without summarization.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
image_data: Raw image bytes
|
||||
file_id: Base identifier for the file
|
||||
display_name: Human-readable name for the image
|
||||
@@ -38,7 +34,7 @@ def store_image_and_create_section(
|
||||
"""
|
||||
# Storage logic
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(image_data),
|
||||
display_name=display_name,
|
||||
|
||||
backend/onyx/file_store/document_batch_storage.py (new file, 294 lines)
@@ -0,0 +1,294 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from io import StringIO
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import TypeAlias
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import DocIndexingContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.file_store.file_store import FileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentBatchStorageStateType(str, Enum):
|
||||
EXTRACTION = "extraction"
|
||||
INDEXING = "indexing"
|
||||
|
||||
|
||||
DocumentStorageState: TypeAlias = DocExtractionContext | DocIndexingContext
|
||||
|
||||
STATE_TYPE_TO_MODEL: dict[str, type[DocumentStorageState]] = {
|
||||
DocumentBatchStorageStateType.EXTRACTION.value: DocExtractionContext,
|
||||
DocumentBatchStorageStateType.INDEXING.value: DocIndexingContext,
|
||||
}
|
||||
|
||||
|
||||
class DocumentBatchStorage(ABC):
|
||||
"""Abstract base class for document batch storage implementations."""
|
||||
|
||||
def __init__(self, tenant_id: str, index_attempt_id: int):
|
||||
self.tenant_id = tenant_id
|
||||
self.index_attempt_id = index_attempt_id
|
||||
self.base_path = f"{tenant_id}/{index_attempt_id}"
|
||||
|
||||
@abstractmethod
|
||||
def store_batch(self, batch_id: str, documents: List[Document]) -> None:
|
||||
"""Store a batch of documents."""
|
||||
|
||||
@abstractmethod
|
||||
def get_batch(self, batch_id: str) -> Optional[List[Document]]:
|
||||
"""Retrieve a batch of documents."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_batch(self, batch_id: str) -> None:
|
||||
"""Delete a specific batch."""
|
||||
|
||||
@abstractmethod
|
||||
def store_extraction_state(self, state: DocExtractionContext) -> None:
|
||||
"""Store extraction state metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def get_extraction_state(self) -> DocExtractionContext | None:
|
||||
"""Get extraction state metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def store_indexing_state(self, state: DocIndexingContext) -> None:
|
||||
"""Store indexing state metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def get_indexing_state(self) -> DocIndexingContext | None:
|
||||
"""Get indexing state metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_all_batches(self) -> None:
|
||||
"""Clean up all batches and state for this index attempt."""
|
||||
|
||||
def ensure_extraction_state(self) -> DocExtractionContext:
|
||||
"""Ensure extraction state exists."""
|
||||
state = self.get_extraction_state()
|
||||
if not state:
|
||||
raise RuntimeError("Expected extraction state not found")
|
||||
return state
|
||||
|
||||
def ensure_indexing_state(self) -> DocIndexingContext:
|
||||
"""Ensure indexing state exists."""
|
||||
state = self.get_indexing_state()
|
||||
if not state:
|
||||
raise RuntimeError("Expected indexing state not found")
|
||||
return state
|
||||
|
||||
def _serialize_documents(self, documents: list[Document]) -> str:
|
||||
"""Serialize documents to JSON string."""
|
||||
# Use mode='json' to properly serialize datetime and other complex types
|
||||
return json.dumps([doc.model_dump(mode="json") for doc in documents], indent=2)
|
||||
|
||||
def _deserialize_documents(self, data: str) -> list[Document]:
|
||||
"""Deserialize documents from JSON string."""
|
||||
doc_dicts = json.loads(data)
|
||||
return [Document.model_validate(doc_dict) for doc_dict in doc_dicts]
|
||||
|
||||
|
||||
class FileStoreDocumentBatchStorage(DocumentBatchStorage):
|
||||
"""FileStore-based implementation of document batch storage."""
|
||||
|
||||
def __init__(self, tenant_id: str, index_attempt_id: int, file_store: FileStore):
|
||||
super().__init__(tenant_id, index_attempt_id)
|
||||
self.file_store = file_store
|
||||
# Track stored batch files for cleanup
|
||||
self._batch_files: set[str] = set()
|
||||
|
||||
def _get_batch_file_name(self, batch_id: str) -> str:
|
||||
"""Generate file name for a document batch."""
|
||||
return f"document_batch_{self.base_path.replace('/', '_')}_{batch_id}.json"
|
||||
|
||||
def _get_state_file_name(self, state_type: str) -> str:
|
||||
"""Generate file name for extraction state."""
|
||||
return f"{state_type}_state_{self.base_path.replace('/', '_')}.json"
|
||||
|
||||
def store_batch(self, batch_id: str, documents: list[Document]) -> None:
|
||||
"""Store a batch of documents using FileStore."""
|
||||
file_name = self._get_batch_file_name(batch_id)
|
||||
try:
|
||||
data = self._serialize_documents(documents)
|
||||
content = StringIO(data)
|
||||
|
||||
self.file_store.save_file(
|
||||
file_id=file_name,
|
||||
content=content,
|
||||
display_name=f"Document Batch {batch_id}",
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="application/json",
|
||||
file_metadata={
|
||||
"tenant_id": self.tenant_id,
|
||||
"index_attempt_id": str(self.index_attempt_id),
|
||||
"batch_id": batch_id,
|
||||
"document_count": str(len(documents)),
|
||||
},
|
||||
)
|
||||
|
||||
# Track this batch file for cleanup
|
||||
self._batch_files.add(file_name)
|
||||
|
||||
logger.debug(
|
||||
f"Stored batch {batch_id} with {len(documents)} documents to FileStore as {file_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store batch {batch_id}: {e}")
|
||||
raise
|
||||
|
||||
def get_batch(self, batch_id: str) -> list[Document] | None:
|
||||
"""Retrieve a batch of documents from FileStore."""
|
||||
file_name = self._get_batch_file_name(batch_id)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not self.file_store.has_file(
|
||||
file_id=file_name,
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="application/json",
|
||||
):
|
||||
logger.warning(f"Batch {batch_id} not found in FileStore")
|
||||
return None
|
||||
|
||||
content_io = self.file_store.read_file(file_name)
|
||||
data = content_io.read().decode("utf-8")
|
||||
|
||||
documents = self._deserialize_documents(data)
|
||||
logger.debug(
|
||||
f"Retrieved batch {batch_id} with {len(documents)} documents from FileStore"
|
||||
)
|
||||
return documents
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve batch {batch_id}: {e}")
|
||||
raise
|
||||
|
||||
def delete_batch(self, batch_id: str) -> None:
|
||||
"""Delete a specific batch from FileStore."""
|
||||
file_name = self._get_batch_file_name(batch_id)
|
||||
try:
|
||||
self.file_store.delete_file(file_name)
|
||||
# Remove from tracked files
|
||||
self._batch_files.discard(file_name)
|
||||
logger.debug(f"Deleted batch {batch_id} from FileStore")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete batch {batch_id}: {e}")
|
||||
# Don't raise - batch might not exist, which is acceptable
|
||||
|
||||
def _store_state(self, state: DocumentStorageState, state_type: str) -> None:
|
||||
"""Store state using FileStore."""
|
||||
file_name = self._get_state_file_name(state_type)
|
||||
try:
|
||||
data = json.dumps(state.model_dump(mode="json"), indent=2)
|
||||
content = StringIO(data)
|
||||
|
||||
self.file_store.save_file(
|
||||
file_id=file_name,
|
||||
content=content,
|
||||
display_name=f"{state_type.capitalize()} State {self.base_path}",
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="application/json",
|
||||
file_metadata={
|
||||
"tenant_id": self.tenant_id,
|
||||
"index_attempt_id": str(self.index_attempt_id),
|
||||
"type": f"{state_type}_state",
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"Stored {state_type} state to FileStore as {file_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store {state_type} state: {e}")
|
||||
raise
|
||||
|
||||
def _get_state(self, state_type: str) -> DocumentStorageState | None:
|
||||
"""Get state from FileStore."""
|
||||
file_name = self._get_state_file_name(state_type)
|
||||
try:
|
||||
# Check if file exists
|
||||
if not self.file_store.has_file(
|
||||
file_id=file_name,
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="application/json",
|
||||
):
|
||||
return None
|
||||
|
||||
content_io = self.file_store.read_file(file_name)
|
||||
data = content_io.read().decode("utf-8")
|
||||
|
||||
state = STATE_TYPE_TO_MODEL[state_type].model_validate(json.loads(data))
|
||||
logger.debug(f"Retrieved {state_type} state from FileStore")
|
||||
return state
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve {state_type} state: {e}")
|
||||
return None
|
||||
|
||||
def store_extraction_state(self, state: DocExtractionContext) -> None:
|
||||
"""Store extraction state using FileStore."""
|
||||
self._store_state(state, DocumentBatchStorageStateType.EXTRACTION.value)
|
||||
|
||||
def get_extraction_state(self) -> DocExtractionContext | None:
|
||||
"""Get extraction state from FileStore."""
|
||||
return cast(
|
||||
DocExtractionContext | None,
|
||||
self._get_state(DocumentBatchStorageStateType.EXTRACTION.value),
|
||||
)
|
||||
|
||||
def store_indexing_state(self, state: DocIndexingContext) -> None:
|
||||
"""Store indexing state using FileStore."""
|
||||
self._store_state(state, DocumentBatchStorageStateType.INDEXING.value)
|
||||
|
||||
def get_indexing_state(self) -> DocIndexingContext | None:
|
||||
"""Get indexing state from FileStore."""
|
||||
return cast(
|
||||
DocIndexingContext | None,
|
||||
self._get_state(DocumentBatchStorageStateType.INDEXING.value),
|
||||
)
|
||||
|
||||
def cleanup_all_batches(self) -> None:
|
||||
"""Clean up all batches and state for this index attempt."""
|
||||
# Since we don't have direct access to S3 listing logic here,
|
||||
# we'll rely on deleting tracked files.
|
||||
# A more robust cleanup might involve a separate task that can list/delete
|
||||
# from S3 if needed, but for now this is simpler.
|
||||
deleted_count = 0
|
||||
|
||||
# Create a copy of the set to avoid issues with modification during iteration
|
||||
for file_name in list(self._batch_files):
|
||||
try:
|
||||
self.file_store.delete_file(file_name)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete batch file {file_name}: {e}")
|
||||
|
||||
# Delete state
|
||||
for state_type in DocumentBatchStorageStateType:
|
||||
try:
|
||||
state_file_name = self._get_state_file_name(state_type.value)
|
||||
self.file_store.delete_file(state_file_name)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete extraction state: {e}")
|
||||
|
||||
# Clear tracked files
|
||||
self._batch_files.clear()
|
||||
|
||||
logger.info(
|
||||
f"Cleaned up {deleted_count} files for index attempt {self.index_attempt_id}"
|
||||
)
|
||||
|
||||
|
||||
def get_document_batch_storage(
|
||||
tenant_id: str, index_attempt_id: int
|
||||
) -> DocumentBatchStorage:
|
||||
"""Factory function to get the configured document batch storage implementation."""
|
||||
# The get_default_file_store will now correctly use S3BackedFileStore
|
||||
# or other configured stores based on environment variables
|
||||
file_store = get_default_file_store()
|
||||
return FileStoreDocumentBatchStorage(tenant_id, index_attempt_id, file_store)
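A minimal usage sketch of the batch storage introduced above (illustrative only, not part of the diff; the batch id value and the surrounding task are assumptions): a docfetching step stores a batch via the factory, and a later docprocessing step reads it back and deletes it once indexed.

from onyx.connectors.models import Document
from onyx.file_store.document_batch_storage import get_document_batch_storage


def example_round_trip(tenant_id: str, index_attempt_id: int, documents: list[Document]) -> None:
    # Obtain the FileStore-backed storage scoped to this tenant / index attempt.
    storage = get_document_batch_storage(tenant_id, index_attempt_id)

    # Persist one batch under an arbitrary batch id (naming scheme assumed here).
    storage.store_batch("0001", documents)

    # Later, e.g. in a docprocessing task, read the batch back.
    retrieved = storage.get_batch("0001")
    if retrieved is not None:
        pass  # chunk / embed / index the retrieved documents

    # Remove the batch once it has been fully processed.
    storage.delete_batch("0001")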
|
||||
@@ -22,10 +22,14 @@ from onyx.configs.app_configs import S3_FILE_STORE_BUCKET_NAME
|
||||
from onyx.configs.app_configs import S3_FILE_STORE_PREFIX
|
||||
from onyx.configs.app_configs import S3_VERIFY_SSL
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none
|
||||
from onyx.db.file_record import delete_filerecord_by_file_id
|
||||
from onyx.db.file_record import get_filerecord_by_file_id
|
||||
from onyx.db.file_record import get_filerecord_by_file_id_optional
|
||||
from onyx.db.file_record import get_filerecord_by_prefix
|
||||
from onyx.db.file_record import upsert_filerecord
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import FileRecord as FileStoreModel
|
||||
from onyx.file_store.s3_key_utils import generate_s3_key
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
@@ -129,13 +133,18 @@ class FileStore(ABC):
|
||||
Get the file + parse out the mime type.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
|
||||
"""
|
||||
List all file IDs that start with the given prefix.
|
||||
"""
|
||||
|
||||
|
||||
class S3BackedFileStore(FileStore):
|
||||
"""Isn't necessarily S3, but is any S3-compatible storage (e.g. MinIO)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_session: Session,
|
||||
bucket_name: str,
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
@@ -144,7 +153,6 @@ class S3BackedFileStore(FileStore):
|
||||
s3_prefix: str | None = None,
|
||||
s3_verify_ssl: bool = True,
|
||||
) -> None:
|
||||
self.db_session = db_session
|
||||
self._s3_client: S3Client | None = None
|
||||
self._bucket_name = bucket_name
|
||||
self._aws_access_key_id = aws_access_key_id
|
||||
@@ -272,10 +280,12 @@ class S3BackedFileStore(FileStore):
|
||||
file_id: str,
|
||||
file_origin: FileOrigin,
|
||||
file_type: str,
|
||||
db_session: Session | None = None,
|
||||
) -> bool:
|
||||
file_record = get_filerecord_by_file_id_optional(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
file_record = get_filerecord_by_file_id_optional(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
return (
|
||||
file_record is not None
|
||||
and file_record.file_origin == file_origin
|
||||
@@ -290,6 +300,7 @@ class S3BackedFileStore(FileStore):
|
||||
file_type: str,
|
||||
file_metadata: dict[str, Any] | None = None,
|
||||
file_id: str | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> str:
|
||||
if file_id is None:
|
||||
file_id = str(uuid.uuid4())
|
||||
@@ -314,27 +325,33 @@ class S3BackedFileStore(FileStore):
|
||||
ContentType=file_type,
|
||||
)
|
||||
|
||||
# Save metadata to database
|
||||
upsert_filerecord(
|
||||
file_id=file_id,
|
||||
display_name=display_name or file_id,
|
||||
file_origin=file_origin,
|
||||
file_type=file_type,
|
||||
bucket_name=bucket_name,
|
||||
object_key=s3_key,
|
||||
db_session=self.db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
self.db_session.commit()
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
# Save metadata to database
|
||||
upsert_filerecord(
|
||||
file_id=file_id,
|
||||
display_name=display_name or file_id,
|
||||
file_origin=file_origin,
|
||||
file_type=file_type,
|
||||
bucket_name=bucket_name,
|
||||
object_key=s3_key,
|
||||
db_session=db_session,
|
||||
file_metadata=file_metadata,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return file_id
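The get_session_with_current_tenant_if_none helper is used throughout these changes but its implementation is not shown in this diff. A plausible sketch, stated purely as an assumption: a context manager that reuses a caller-provided session and otherwise opens a fresh tenant-scoped one.

from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy.orm import Session

from onyx.db.engine.sql_engine import get_session_with_current_tenant


@contextmanager
def get_session_with_current_tenant_if_none(
    db_session: Session | None,
) -> Iterator[Session]:
    # Reuse the caller's session if one was passed in...
    if db_session is not None:
        yield db_session
    else:
        # ...otherwise open (and close) a tenant-scoped session just for this call.
        with get_session_with_current_tenant() as new_session:
            yield new_session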
|
||||
|
||||
def read_file(
|
||||
self, file_id: str, mode: str | None = None, use_tempfile: bool = False
|
||||
self,
|
||||
file_id: str,
|
||||
mode: str | None = None,
|
||||
use_tempfile: bool = False,
|
||||
db_session: Session | None = None,
|
||||
) -> IO[bytes]:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
|
||||
s3_client = self._get_s3_client()
|
||||
try:
|
||||
@@ -356,32 +373,37 @@ class S3BackedFileStore(FileStore):
|
||||
else:
|
||||
return BytesIO(file_content)
|
||||
|
||||
def read_file_record(self, file_id: str) -> FileStoreModel:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
def read_file_record(
|
||||
self, file_id: str, db_session: Session | None = None
|
||||
) -> FileStoreModel:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
return file_record
|
||||
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
try:
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=self.db_session
|
||||
)
|
||||
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
try:
|
||||
|
||||
# Delete from external storage
|
||||
s3_client = self._get_s3_client()
|
||||
s3_client.delete_object(
|
||||
Bucket=file_record.bucket_name, Key=file_record.object_key
|
||||
)
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
|
||||
# Delete metadata from database
|
||||
delete_filerecord_by_file_id(file_id=file_id, db_session=self.db_session)
|
||||
# Delete from external storage
|
||||
s3_client = self._get_s3_client()
|
||||
s3_client.delete_object(
|
||||
Bucket=file_record.bucket_name, Key=file_record.object_key
|
||||
)
|
||||
|
||||
self.db_session.commit()
|
||||
# Delete metadata from database
|
||||
delete_filerecord_by_file_id(file_id=file_id, db_session=db_session)
|
||||
|
||||
except Exception:
|
||||
self.db_session.rollback()
|
||||
raise
|
||||
db_session.commit()
|
||||
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
|
||||
mime_type: str = "application/octet-stream"
|
||||
@@ -395,8 +417,18 @@ class S3BackedFileStore(FileStore):
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def list_files_by_prefix(self, prefix: str) -> list[FileRecord]:
|
||||
"""
|
||||
List all file IDs that start with the given prefix.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_records = get_filerecord_by_prefix(
|
||||
prefix=prefix, db_session=db_session
|
||||
)
|
||||
return file_records
|
||||
|
||||
def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
|
||||
|
||||
def get_s3_file_store() -> S3BackedFileStore:
|
||||
"""
|
||||
Returns the S3 file store implementation.
|
||||
"""
|
||||
@@ -409,7 +441,6 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
|
||||
)
|
||||
|
||||
return S3BackedFileStore(
|
||||
db_session=db_session,
|
||||
bucket_name=bucket_name,
|
||||
aws_access_key_id=S3_AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY,
|
||||
@@ -420,7 +451,7 @@ def get_s3_file_store(db_session: Session) -> S3BackedFileStore:
|
||||
)
|
||||
|
||||
|
||||
def get_default_file_store(db_session: Session) -> FileStore:
|
||||
def get_default_file_store() -> FileStore:
|
||||
"""
|
||||
Returns the configured file store implementation.
|
||||
|
||||
@@ -445,4 +476,4 @@ def get_default_file_store(db_session: Session) -> FileStore:
|
||||
Other S3-compatible storage (Digital Ocean, Linode, etc.):
|
||||
- Same as MinIO, but set appropriate S3_ENDPOINT_URL
|
||||
"""
|
||||
return get_s3_file_store(db_session)
|
||||
return get_s3_file_store()
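As a rough configuration sketch for this factory (the environment variable names are assumed to match the config constants imported above, the values are placeholders, and the minioadmin credentials mirror the test defaults further down), pointing the default file store at a local MinIO instance could look like this.

import os

# Placeholder values for a local MinIO setup; real deployments set these in the
# environment. The variable names are assumptions based on the config constants.
os.environ.setdefault("S3_ENDPOINT_URL", "http://localhost:9000")
os.environ.setdefault("S3_FILE_STORE_BUCKET_NAME", "onyx-file-store-bucket")
os.environ.setdefault("S3_AWS_ACCESS_KEY_ID", "minioadmin")
os.environ.setdefault("S3_AWS_SECRET_ACCESS_KEY", "minioadmin")

# These must be set before the onyx config modules are imported, since they are
# read into module-level constants at import time.
from onyx.file_store.file_store import get_default_file_store

file_store = get_default_file_store()
file_store.initialize()  # e.g. create the bucket if it does not already exist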
|
||||
|
||||
@@ -8,7 +8,6 @@ import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
@@ -49,28 +48,23 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
|
||||
|
||||
# Use a separate session to avoid committing the caller's transaction
|
||||
try:
|
||||
with get_session_with_current_tenant() as file_store_session:
|
||||
file_store = get_default_file_store(file_store_session)
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
) -> InMemoryChatFile:
|
||||
file_io = get_default_file_store(db_session).read_file(
|
||||
file_descriptor["id"], mode="b"
|
||||
)
|
||||
def load_chat_file(file_descriptor: FileDescriptor) -> InMemoryChatFile:
|
||||
file_io = get_default_file_store().read_file(file_descriptor["id"], mode="b")
|
||||
return InMemoryChatFile(
|
||||
file_id=file_descriptor["id"],
|
||||
content=file_io.read(),
|
||||
@@ -82,7 +76,6 @@ def load_chat_file(
|
||||
def load_all_chat_files(
|
||||
chat_messages: list[ChatMessage],
|
||||
file_descriptors: list[FileDescriptor],
|
||||
db_session: Session,
|
||||
) -> list[InMemoryChatFile]:
|
||||
file_descriptors_for_history: list[FileDescriptor] = []
|
||||
for chat_message in chat_messages:
|
||||
@@ -93,7 +86,7 @@ def load_all_chat_files(
|
||||
list[InMemoryChatFile],
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(load_chat_file, (file, db_session))
|
||||
(load_chat_file, (file,))
|
||||
for file in file_descriptors + file_descriptors_for_history
|
||||
]
|
||||
),
|
||||
@@ -117,7 +110,7 @@ def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
|
||||
raise ValueError(f"User file with id {file_id} not found")
|
||||
|
||||
# Get the file record to determine the appropriate chat file type
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_record = file_store.read_file_record(user_file.file_id)
|
||||
|
||||
# Determine appropriate chat file type based on the original file's MIME type
|
||||
@@ -263,34 +256,29 @@ def get_user_files_as_user(
|
||||
|
||||
|
||||
def save_file_from_url(url: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
file_io = BytesIO(response.content)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_id = file_store.save_file(
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png;base64",
|
||||
)
|
||||
return file_id
|
||||
file_io = BytesIO(response.content)
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png;base64",
|
||||
)
|
||||
return file_id
|
||||
|
||||
|
||||
def save_file_from_base64(base64_string: str) -> str:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(base64.b64decode(base64_string)),
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type=get_image_type(base64_string),
|
||||
)
|
||||
return file_id
|
||||
file_store = get_default_file_store()
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(base64.b64decode(base64_string)),
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type=get_image_type(base64_string),
|
||||
)
|
||||
return file_id
|
||||
|
||||
|
||||
def save_file(
|
||||
|
||||
@@ -4,6 +4,13 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
|
||||
def make_default_kwargs() -> dict[str, Any]:
|
||||
return {
|
||||
"http2": True,
|
||||
"limits": httpx.Limits(),
|
||||
}
|
||||
|
||||
|
||||
class HttpxPool:
|
||||
"""Class to manage a global httpx Client instance"""
|
||||
|
||||
@@ -11,10 +18,6 @@ class HttpxPool:
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# Default parameters for creation
|
||||
DEFAULT_KWARGS = {
|
||||
"http2": True,
|
||||
"limits": lambda: httpx.Limits(),
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
@@ -22,7 +25,7 @@ class HttpxPool:
|
||||
@classmethod
|
||||
def _init_client(cls, **kwargs: Any) -> httpx.Client:
|
||||
"""Private helper method to create and return an httpx.Client."""
|
||||
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
|
||||
merged_kwargs = {**(make_default_kwargs()), **kwargs}
|
||||
return httpx.Client(**merged_kwargs)
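A likely motivation for swapping the class-level DEFAULT_KWARGS for make_default_kwargs() (an inference, not stated in the diff): the old dict stored a lambda under "limits", so the kwargs merge handed the callable itself to httpx.Client rather than an httpx.Limits, and a class-level dict would also share one set of defaults across every client. A small illustration:

import httpx

# Old shape (removed above): the lambda is never called during the merge, so
# "limits" ends up as a function object rather than an httpx.Limits instance.
broken_defaults = {"http2": True, "limits": lambda: httpx.Limits()}
assert callable({**broken_defaults}["limits"])  # not what httpx.Client expects

# New shape: make_default_kwargs() builds a fresh, correctly typed Limits per client.
good_defaults = {"http2": True, "limits": httpx.Limits()}
assert isinstance(good_defaults["limits"], httpx.Limits)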
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -41,7 +41,6 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -496,36 +495,33 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
|
||||
# Try to get image summary
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
file_record = file_store.read_file_record(
|
||||
file_record = file_store.read_file_record(
|
||||
file_id=section.image_file_id
|
||||
)
|
||||
if not file_record:
|
||||
logger.warning(
|
||||
f"Image file {section.image_file_id} not found in FileStore"
|
||||
)
|
||||
|
||||
processed_section.text = "[Image could not be processed]"
|
||||
else:
|
||||
# Get the image data
|
||||
image_data_io = file_store.read_file(
|
||||
file_id=section.image_file_id
|
||||
)
|
||||
if not file_record:
|
||||
logger.warning(
|
||||
f"Image file {section.image_file_id} not found in FileStore"
|
||||
)
|
||||
image_data = image_data_io.read()
|
||||
summary = summarize_image_with_error_handling(
|
||||
llm=llm,
|
||||
image_data=image_data,
|
||||
context_name=file_record.display_name or "Image",
|
||||
)
|
||||
|
||||
processed_section.text = "[Image could not be processed]"
|
||||
if summary:
|
||||
processed_section.text = summary
|
||||
else:
|
||||
# Get the image data
|
||||
image_data_io = file_store.read_file(
|
||||
file_id=section.image_file_id
|
||||
)
|
||||
image_data = image_data_io.read()
|
||||
summary = summarize_image_with_error_handling(
|
||||
llm=llm,
|
||||
image_data=image_data,
|
||||
context_name=file_record.display_name or "Image",
|
||||
)
|
||||
|
||||
if summary:
|
||||
processed_section.text = summary
|
||||
else:
|
||||
processed_section.text = (
|
||||
"[Image could not be summarized]"
|
||||
)
|
||||
processed_section.text = "[Image could not be summarized]"
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image section: {e}")
|
||||
processed_section.text = "[Error processing image]"
|
||||
|
||||
@@ -258,7 +258,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
# set up the file store (e.g. create bucket if needed). On multi-tenant,
|
||||
# this is done via IaC
|
||||
get_default_file_store(db_session).initialize()
|
||||
get_default_file_store().initialize()
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
|
||||
@@ -16,24 +16,26 @@ class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks / associated redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, cc_pair_id: int) -> None:
|
||||
"""id: a connector credential pair id"""
|
||||
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.cc_pair_id: int = cc_pair_id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
|
||||
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
|
||||
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
|
||||
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
|
||||
self.stop = RedisConnectorStop(tenant_id, cc_pair_id, self.redis)
|
||||
self.prune = RedisConnectorPrune(tenant_id, cc_pair_id, self.redis)
|
||||
self.delete = RedisConnectorDelete(tenant_id, cc_pair_id, self.redis)
|
||||
self.permissions = RedisConnectorPermissionSync(
|
||||
tenant_id, cc_pair_id, self.redis
|
||||
)
|
||||
self.external_group_sync = RedisConnectorExternalGroupSync(
|
||||
tenant_id, id, self.redis
|
||||
tenant_id, cc_pair_id, self.redis
|
||||
)
|
||||
|
||||
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
|
||||
return RedisConnectorIndex(
|
||||
self.tenant_id, self.id, search_settings_id, self.redis
|
||||
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
|
||||
)
|
||||
|
||||
def wait_for_indexing_termination(
|
||||
|
||||
@@ -58,10 +58,7 @@ class RedisConnectorDelete:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorDeletePayload | None:
|
||||
@@ -93,10 +90,7 @@ class RedisConnectorDelete:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
|
||||
@@ -88,10 +88,7 @@ class RedisConnectorPermissionSync:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
|
||||
@@ -128,10 +125,7 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
|
||||
@@ -84,10 +84,7 @@ class RedisConnectorExternalGroupSync:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
|
||||
@@ -125,10 +122,7 @@ class RedisConnectorExternalGroupSync:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
|
||||
@@ -31,7 +31,10 @@ class RedisConnectorIndex:
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorindexing_generator_complete
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
|
||||
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
|
||||
DB_LOCK_PREFIX = "da_lock:indexing:db"
|
||||
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
|
||||
|
||||
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
|
||||
TERMINATE_TTL = 600
|
||||
@@ -53,33 +56,45 @@ class RedisConnectorIndex:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.cc_pair_id = cc_pair_id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
self.generator_progress_key = (
|
||||
f"{self.GENERATOR_PROGRESS_PREFIX}_{id}/{search_settings_id}"
|
||||
f"{self.GENERATOR_PROGRESS_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_complete_key = (
|
||||
f"{self.GENERATOR_COMPLETE_PREFIX}_{id}/{search_settings_id}"
|
||||
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.filestore_lock_key = (
|
||||
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_lock_key = (
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
|
||||
self.per_worker_lock_key = (
|
||||
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
self.terminate_key = (
|
||||
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
self.connector_active_key = (
|
||||
f"{self.CONNECTOR_ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
f"{self.CONNECTOR_ACTIVE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
|
||||
def lock_key_by_batch(self, batch_n: int) -> str:
|
||||
return f"{self.per_worker_lock_key}/{batch_n}"
|
||||
|
||||
@classmethod
|
||||
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
|
||||
return f"{cls.FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
@@ -89,7 +104,7 @@ class RedisConnectorIndex:
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "connectorindexing+generator_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self.id}/{self.search_settings_id}_{uuid4()}"
|
||||
return f"{self.GENERATOR_TASK_PREFIX}_{self.cc_pair_id}/{self.search_settings_id}_{uuid4()}"
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
@@ -160,10 +175,7 @@ class RedisConnectorIndex:
|
||||
self.redis.set(self.connector_active_key, 0, ex=self.CONNECTOR_ACTIVE_TTL)
|
||||
|
||||
def connector_active(self) -> bool:
|
||||
if self.redis.exists(self.connector_active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.connector_active_key))
|
||||
|
||||
def connector_active_ttl(self) -> int:
|
||||
"""Refer to https://redis.io/docs/latest/commands/ttl/
|
||||
@@ -175,9 +187,6 @@ class RedisConnectorIndex:
|
||||
ttl = cast(int, self.redis.ttl(self.connector_active_key))
|
||||
return ttl
|
||||
|
||||
def generator_locked(self) -> bool:
|
||||
return bool(self.redis.exists(self.generator_lock_key))
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
@@ -212,6 +221,8 @@ class RedisConnectorIndex:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.connector_active_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.filestore_lock_key)
|
||||
self.redis.delete(self.db_lock_key)
|
||||
self.redis.delete(self.generator_lock_key)
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
@@ -226,7 +237,7 @@ class RedisConnectorIndex:
|
||||
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
|
||||
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
|
||||
@@ -92,10 +92,7 @@ class RedisConnectorPrune:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorPrunePayload | None:
|
||||
@@ -130,10 +127,7 @@ class RedisConnectorPrune:
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.active_key))
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
|
||||
@@ -23,10 +23,7 @@ class RedisConnectorStop:
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
def set_fence(self, value: bool) -> None:
|
||||
if not value:
|
||||
@@ -37,10 +34,7 @@ class RedisConnectorStop:
|
||||
|
||||
@property
|
||||
def timed_out(self) -> bool:
|
||||
if self.redis.exists(self.timeout_key):
|
||||
return False
|
||||
|
||||
return True
|
||||
return not bool(self.redis.exists(self.timeout_key))
|
||||
|
||||
def set_timeout(self) -> None:
|
||||
"""After calling this, call timed_out to determine if the timeout has been
|
||||
|
||||
@@ -418,7 +418,7 @@ def extract_zip_metadata(zf: zipfile.ZipFile) -> dict[str, Any]:
|
||||
return zip_metadata
|
||||
|
||||
|
||||
def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResponse:
|
||||
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File name cannot be empty")
|
||||
@@ -431,7 +431,7 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
|
||||
deduped_file_paths = []
|
||||
zip_metadata = {}
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
seen_zip = False
|
||||
for file in files:
|
||||
if file.content_type and file.content_type.startswith("application/zip"):
|
||||
@@ -488,9 +488,8 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files, db_session)
|
||||
return upload_files(files)
|
||||
|
||||
|
||||
@router.get("/admin/connector")
|
||||
|
||||
@@ -176,10 +176,9 @@ def undelete_persona(
|
||||
@admin_router.post("/upload-image")
|
||||
def upload_file(
|
||||
file: UploadFile,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> dict[str, str]:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_type = ChatFileType.IMAGE
|
||||
file_id = file_store.save_file(
|
||||
content=file.file,
|
||||
|
||||
@@ -205,6 +205,6 @@ def create_deletion_attempt_for_connector_id(
|
||||
|
||||
if cc_pair.connector.source == DocumentSource.FILE:
|
||||
connector = cc_pair.connector
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
for file_name in connector.connector_specific_config.get("file_locations", []):
|
||||
file_store.delete_file(file_name)
|
||||
|
||||
@@ -720,7 +720,7 @@ def upload_files_for_chat(
|
||||
detail="File size must be less than 20MB",
|
||||
)
|
||||
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
file_info: list[tuple[str, str | None, ChatFileType]] = []
|
||||
for file in files:
|
||||
@@ -823,10 +823,9 @@ def upload_files_for_chat(
|
||||
@router.get("/file/{file_id:path}")
|
||||
def fetch_chat_file(
|
||||
file_id: str,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> Response:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAU
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.file import FileWithMimeType
|
||||
@@ -40,9 +39,8 @@ class OnyxRuntime:
|
||||
onyx_file: FileWithMimeType | None = None
|
||||
|
||||
if db_filename:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
onyx_file = file_store.get_file_with_mime_type(db_filename)
|
||||
file_store = get_default_file_store()
|
||||
onyx_file = file_store.get_file_with_mime_type(db_filename)
|
||||
|
||||
if not onyx_file:
|
||||
onyx_file = OnyxStaticFileManager.get_static(static_filename)
|
||||
|
||||
@@ -17,7 +17,6 @@ from requests import JSONDecodeError
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
@@ -205,27 +204,26 @@ class CustomTool(BaseTool):
|
||||
def _save_and_get_file_references(
|
||||
self, file_content: bytes | str, content_type: str
|
||||
) -> List[str]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
file_id = str(uuid.uuid4())
|
||||
|
||||
# Handle both binary and text content
|
||||
if isinstance(file_content, str):
|
||||
content = BytesIO(file_content.encode())
|
||||
else:
|
||||
content = BytesIO(file_content)
|
||||
# Handle both binary and text content
|
||||
if isinstance(file_content, str):
|
||||
content = BytesIO(file_content.encode())
|
||||
else:
|
||||
content = BytesIO(file_content)
|
||||
|
||||
file_store.save_file(
|
||||
file_id=file_id,
|
||||
content=content,
|
||||
display_name=file_id,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=content_type,
|
||||
file_metadata={
|
||||
"content_type": content_type,
|
||||
},
|
||||
)
|
||||
file_store.save_file(
|
||||
file_id=file_id,
|
||||
content=content,
|
||||
display_name=file_id,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=content_type,
|
||||
file_metadata={
|
||||
"content_type": content_type,
|
||||
},
|
||||
)
|
||||
|
||||
return [file_id]
|
||||
|
||||
@@ -328,22 +326,21 @@ class CustomTool(BaseTool):
|
||||
|
||||
# Load files from storage
|
||||
files = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for file_id in response.tool_result.file_ids:
|
||||
try:
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
files.append(
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
filename=file_id,
|
||||
content=file_io.read(),
|
||||
file_type=file_type,
|
||||
)
|
||||
for file_id in response.tool_result.file_ids:
|
||||
try:
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
files.append(
|
||||
InMemoryChatFile(
|
||||
file_id=file_id,
|
||||
filename=file_id,
|
||||
content=file_io.read(),
|
||||
file_type=file_type,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file {file_id}")
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to read file {file_id}")
|
||||
|
||||
# Update prompt with file content
|
||||
prompt_builder.update_user_prompt(
|
||||
|
||||
@@ -384,3 +384,24 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
|
||||
)
|
||||
next_ind += 1
|
||||
del future_to_index[future]
|
||||
|
||||
|
||||
def parallel_yield_from_funcs(
|
||||
funcs: list[Callable[..., R]],
|
||||
max_workers: int = 10,
|
||||
) -> Iterator[R]:
|
||||
"""
|
||||
Runs the list of functions with thread-level parallelism, yielding
|
||||
results as available. The asynchronous nature of this yielding means
|
||||
that stopping the returned iterator early DOES NOT GUARANTEE THAT NO
|
||||
FURTHER ITEMS WERE PRODUCED by the input funcs. Only use this function
|
||||
if you are consuming all elements from the functions OR it is acceptable
|
||||
for some extra function code to run and not have the result(s) yielded.
|
||||
"""
|
||||
|
||||
def func_wrapper(func: Callable[[], R]) -> Iterator[R]:
|
||||
yield func()
|
||||
|
||||
yield from parallel_yield(
|
||||
[func_wrapper(func) for func in funcs], max_workers=max_workers
|
||||
)
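An illustrative call (not from the diff; fetch_page is a made-up stand-in for I/O-bound work) that consumes everything the functions produce, which is the safe way to use this helper given the caveat in the docstring:

from functools import partial


def fetch_page(page: int) -> list[int]:
    # Stand-in for per-page I/O such as an API request.
    return list(range(page * 10, page * 10 + 10))


funcs = [partial(fetch_page, page) for page in range(5)]

# Each yielded item is one function's return value; ordering is not guaranteed.
all_rows: list[int] = []
for rows in parallel_yield_from_funcs(funcs, max_workers=3):
    all_rows.extend(rows)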
|
||||
|
||||
@@ -69,7 +69,7 @@ def run_jobs() -> None:
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"--queues=connector_indexing",
|
||||
"--queues=docprocessing",
|
||||
]
|
||||
|
||||
cmd_worker_user_files_indexing = [
|
||||
@@ -111,6 +111,19 @@ def run_jobs() -> None:
|
||||
"--queues=kg_processing",
|
||||
]
|
||||
|
||||
cmd_worker_docfetching = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"--queues=connector_doc_fetching,user_files_indexing",
|
||||
]
|
||||
|
||||
cmd_beat = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -157,6 +170,13 @@ def run_jobs() -> None:
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_docfetching_process = subprocess.Popen(
|
||||
cmd_worker_docfetching,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
beat_process = subprocess.Popen(
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
@@ -184,6 +204,9 @@ def run_jobs() -> None:
|
||||
worker_kg_processing_thread = threading.Thread(
|
||||
target=monitor_process, args=("KG_PROCESSING", worker_kg_processing_process)
|
||||
)
|
||||
worker_docfetching_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
worker_primary_thread.start()
|
||||
@@ -193,6 +216,7 @@ def run_jobs() -> None:
|
||||
worker_user_files_indexing_thread.start()
|
||||
worker_monitoring_thread.start()
|
||||
worker_kg_processing_thread.start()
|
||||
worker_docfetching_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
worker_primary_thread.join()
|
||||
@@ -202,6 +226,7 @@ def run_jobs() -> None:
|
||||
worker_user_files_indexing_thread.join()
|
||||
worker_monitoring_thread.join()
|
||||
worker_kg_processing_thread.join()
|
||||
worker_docfetching_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
|
||||
|
||||
@@ -213,7 +213,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
|
||||
if file_names:
|
||||
logger.notice("Deleting stored files!")
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store = get_default_file_store()
|
||||
for file_name in file_names:
|
||||
logger.notice(f"Deleting file {file_name}")
|
||||
file_store.delete_file(file_name)
|
||||
|
||||
@@ -69,7 +69,7 @@ stopasgroup=true
|
||||
command=celery -A onyx.background.celery.versioned_apps.indexing worker
|
||||
--loglevel=INFO
|
||||
--hostname=indexing@%%n
|
||||
-Q connector_indexing
|
||||
-Q docprocessing
|
||||
stdout_logfile=/var/log/celery_worker_indexing.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
@@ -89,6 +89,18 @@ autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
[program:celery_worker_docfetching]
|
||||
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
|
||||
--loglevel=INFO
|
||||
--hostname=docfetching@%%n
|
||||
-Q connector_doc_fetching
|
||||
stdout_logfile=/var/log/celery_worker_docfetching.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
[program:celery_worker_monitoring]
|
||||
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
|
||||
--loglevel=INFO
|
||||
@@ -161,6 +173,7 @@ command=tail -qF
|
||||
/var/log/celery_worker_heavy.log
|
||||
/var/log/celery_worker_indexing.log
|
||||
/var/log/celery_worker_user_files_indexing.log
|
||||
/var/log/celery_worker_docfetching.log
|
||||
/var/log/celery_worker_monitoring.log
|
||||
/var/log/slack_bot.log
|
||||
/var/log/supervisord_watchdog_celery_beat.log
|
||||
|
||||
@@ -70,20 +70,19 @@ def _get_all_backend_configs() -> List[BackendConfig]:
|
||||
if S3_ENDPOINT_URL:
|
||||
minio_access_key = "minioadmin"
|
||||
minio_secret_key = "minioadmin"
|
||||
if minio_access_key and minio_secret_key:
|
||||
configs.append(
|
||||
{
|
||||
"endpoint_url": S3_ENDPOINT_URL,
|
||||
"access_key": minio_access_key,
|
||||
"secret_key": minio_secret_key,
|
||||
"region": "us-east-1",
|
||||
"verify_ssl": False,
|
||||
"backend_name": "MinIO",
|
||||
}
|
||||
)
|
||||
configs.append(
|
||||
{
|
||||
"endpoint_url": S3_ENDPOINT_URL,
|
||||
"access_key": minio_access_key,
|
||||
"secret_key": minio_secret_key,
|
||||
"region": "us-east-1",
|
||||
"verify_ssl": False,
|
||||
"backend_name": "MinIO",
|
||||
}
|
||||
)
|
||||
|
||||
# AWS S3 configuration (if credentials are available)
|
||||
if S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
|
||||
elif S3_AWS_ACCESS_KEY_ID and S3_AWS_SECRET_ACCESS_KEY:
|
||||
configs.append(
|
||||
{
|
||||
"endpoint_url": None,
|
||||
@@ -116,7 +115,6 @@ def file_store(
|
||||
|
||||
# Create S3BackedFileStore with backend-specific configuration
|
||||
store = S3BackedFileStore(
|
||||
db_session=db_session,
|
||||
bucket_name=TEST_BUCKET_NAME,
|
||||
aws_access_key_id=backend_config["access_key"],
|
||||
aws_secret_access_key=backend_config["secret_key"],
|
||||
@@ -827,7 +825,6 @@ class TestS3BackedFileStore:
|
||||
# Create a new database session for each worker to avoid conflicts
|
||||
with get_session_with_current_tenant() as worker_session:
|
||||
worker_file_store = S3BackedFileStore(
|
||||
db_session=worker_session,
|
||||
bucket_name=current_bucket_name,
|
||||
aws_access_key_id=current_access_key,
|
||||
aws_secret_access_key=current_secret_key,
|
||||
@@ -849,6 +846,7 @@ class TestS3BackedFileStore:
|
||||
display_name=f"Worker {worker_id} File",
|
||||
file_origin=file_origin,
|
||||
file_type=file_type,
|
||||
db_session=worker_session,
|
||||
)
|
||||
results.append((file_name, content))
|
||||
return True
|
||||
@@ -885,3 +883,94 @@ class TestS3BackedFileStore:
|
||||
read_content_io = file_store.read_file(file_id)
|
||||
actual_content: str = read_content_io.read().decode("utf-8")
|
||||
assert actual_content == expected_content
|
||||
|
||||
def test_list_files_by_prefix(self, file_store: S3BackedFileStore) -> None:
|
||||
"""Test listing files by prefix returns only correctly prefixed files"""
|
||||
test_prefix = "documents-batch-"
|
||||
|
||||
# Files that should be returned (start with the prefix)
|
||||
prefixed_files: List[str] = [
|
||||
f"{test_prefix}001.txt",
|
||||
f"{test_prefix}002.json",
|
||||
f"{test_prefix}abc.pdf",
|
||||
f"{test_prefix}xyz-final.docx",
|
||||
]
|
||||
|
||||
# Files that should NOT be returned (don't start with prefix, even if they contain it)
|
||||
non_prefixed_files: List[str] = [
|
||||
f"other-{test_prefix}001.txt", # Contains prefix but doesn't start with it
|
||||
f"backup-{test_prefix}data.txt", # Contains prefix but doesn't start with it
|
||||
f"{uuid.uuid4()}.txt", # Random file without prefix
|
||||
"reports-001.pdf", # Different prefix
|
||||
f"my-{test_prefix[:-1]}.txt", # Similar but not exact prefix
|
||||
]
|
||||
|
||||
all_files = prefixed_files + non_prefixed_files
|
||||
saved_file_ids: List[str] = []
|
||||
|
||||
# Save all test files
|
||||
for file_name in all_files:
|
||||
content = f"Content for {file_name}"
|
||||
content_io = BytesIO(content.encode("utf-8"))
|
||||
|
||||
returned_file_id = file_store.save_file(
|
||||
content=content_io,
|
||||
display_name=f"Display: {file_name}",
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type="text/plain",
|
||||
file_id=file_name,
|
||||
)
|
||||
saved_file_ids.append(returned_file_id)
|
||||
|
||||
# Verify file was saved
|
||||
assert returned_file_id == file_name
|
||||
|
||||
# Test the list_files_by_prefix functionality
|
||||
prefix_results = file_store.list_files_by_prefix(test_prefix)
|
||||
|
||||
# Extract file IDs from results
|
||||
returned_file_ids = [record.file_id for record in prefix_results]
|
||||
|
||||
# Verify correct number of files returned
|
||||
assert len(returned_file_ids) == len(prefixed_files), (
|
||||
f"Expected {len(prefixed_files)} files with prefix '{test_prefix}', "
|
||||
f"but got {len(returned_file_ids)}: {returned_file_ids}"
|
||||
)
|
||||
|
||||
# Verify all prefixed files are returned
|
||||
for expected_file_id in prefixed_files:
|
||||
assert expected_file_id in returned_file_ids, (
|
||||
f"File '{expected_file_id}' should be in results but was not found. "
|
||||
f"Returned files: {returned_file_ids}"
|
||||
)
|
||||
|
||||
# Verify no non-prefixed files are returned
|
||||
for unexpected_file_id in non_prefixed_files:
|
||||
assert unexpected_file_id not in returned_file_ids, (
|
||||
f"File '{unexpected_file_id}' should NOT be in results but was found. "
|
||||
f"Returned files: {returned_file_ids}"
|
||||
)
|
||||
|
||||
# Verify the returned records have correct properties
|
||||
for record in prefix_results:
|
||||
assert record.file_id.startswith(test_prefix)
|
||||
assert record.display_name == f"Display: {record.file_id}"
|
||||
assert record.file_origin == FileOrigin.OTHER
|
||||
assert record.file_type == "text/plain"
|
||||
assert record.bucket_name == file_store._get_bucket_name()
|
||||
|
||||
# Test with empty prefix (should return all files we created)
|
||||
all_results = file_store.list_files_by_prefix("")
|
||||
all_returned_ids = [record.file_id for record in all_results]
|
||||
|
||||
# Should include all our test files
|
||||
for file_id in saved_file_ids:
|
||||
assert (
|
||||
file_id in all_returned_ids
|
||||
), f"File '{file_id}' should be in results for empty prefix"
|
||||
|
||||
# Test with non-existent prefix
|
||||
nonexistent_results = file_store.list_files_by_prefix("nonexistent-prefix-")
|
||||
assert (
|
||||
len(nonexistent_results) == 0
|
||||
), "Should return empty list for non-existent prefix"
|
||||
|
||||
@@ -77,18 +77,15 @@ def sample_file_io(sample_content: bytes) -> BytesIO:
class TestExternalStorageFileStore:
    """Test external storage file store functionality (S3-compatible)"""

    def test_get_default_file_store_s3(self, db_session: Session) -> None:
    def test_get_default_file_store_s3(self) -> None:
        """Test that external storage file store is returned"""
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        assert isinstance(file_store, S3BackedFileStore)

    def test_s3_client_initialization_with_credentials(
        self, db_session: Session
    ) -> None:
    def test_s3_client_initialization_with_credentials(self) -> None:
        """Test S3 client initialization with explicit credentials"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="test-key",
                aws_secret_access_key="test-secret",
@@ -110,7 +107,6 @@ class TestExternalStorageFileStore:
        """Test S3 client initialization with IAM role (no explicit credentials)"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id=None,
                aws_secret_access_key=None,
@@ -129,16 +125,16 @@ class TestExternalStorageFileStore:
        assert "aws_access_key_id" not in call_kwargs
        assert "aws_secret_access_key" not in call_kwargs

    def test_s3_bucket_name_configuration(self, db_session: Session) -> None:
    def test_s3_bucket_name_configuration(self) -> None:
        """Test S3 bucket name configuration"""
        with patch(
            "onyx.file_store.file_store.S3_FILE_STORE_BUCKET_NAME", "my-test-bucket"
        ):
            file_store = S3BackedFileStore(db_session, bucket_name="my-test-bucket")
            file_store = S3BackedFileStore(bucket_name="my-test-bucket")
            bucket_name: str = file_store._get_bucket_name()
            assert bucket_name == "my-test-bucket"

    def test_s3_key_generation_default_prefix(self, db_session: Session) -> None:
    def test_s3_key_generation_default_prefix(self) -> None:
        """Test S3 key generation with default prefix"""
        with (
            patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"),
@@ -147,11 +143,11 @@ class TestExternalStorageFileStore:
                return_value="test-tenant",
            ),
        ):
            file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
            file_store = S3BackedFileStore(bucket_name="test-bucket")
            s3_key: str = file_store._get_s3_key("test-file.txt")
            assert s3_key == "onyx-files/test-tenant/test-file.txt"

    def test_s3_key_generation_custom_prefix(self, db_session: Session) -> None:
    def test_s3_key_generation_custom_prefix(self) -> None:
        """Test S3 key generation with custom prefix"""
        with (
            patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "custom-prefix"),
@@ -161,17 +157,15 @@ class TestExternalStorageFileStore:
            ),
        ):
            file_store = S3BackedFileStore(
                db_session, bucket_name="test-bucket", s3_prefix="custom-prefix"
                bucket_name="test-bucket", s3_prefix="custom-prefix"
            )
            s3_key: str = file_store._get_s3_key("test-file.txt")
            assert s3_key == "custom-prefix/test-tenant/test-file.txt"

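The change running through these tests is a signature change: S3BackedFileStore and get_default_file_store no longer take a SQLAlchemy Session at construction time. A minimal sketch of the new construction path, assuming the module path that the patches above target (onyx.file_store.file_store):

from onyx.file_store.file_store import S3BackedFileStore, get_default_file_store

# Construction no longer needs a Session
default_store = get_default_file_store()
explicit_store = S3BackedFileStore(bucket_name="test-bucket", s3_prefix="custom-prefix")
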
    def test_s3_key_generation_with_different_tenant_ids(
        self, db_session: Session
    ) -> None:
    def test_s3_key_generation_with_different_tenant_ids(self) -> None:
        """Test S3 key generation with different tenant IDs"""
        with patch("onyx.file_store.file_store.S3_FILE_STORE_PREFIX", "onyx-files"):
            file_store = S3BackedFileStore(db_session, bucket_name="test-bucket")
            file_store = S3BackedFileStore(bucket_name="test-bucket")

            # Test with tenant ID "tenant-1"
            with patch(
@@ -224,9 +218,7 @@ class TestExternalStorageFileStore:
        with patch("onyx.db.file_record.upsert_filerecord") as mock_upsert:
            mock_upsert.return_value = Mock()

            file_store = S3BackedFileStore(
                mock_db_session, bucket_name="test-bucket"
            )
            file_store = S3BackedFileStore(bucket_name="test-bucket")

            # This should not raise an exception
            file_store.save_file(
@@ -235,6 +227,7 @@ class TestExternalStorageFileStore:
                display_name="Test File",
                file_origin=FileOrigin.OTHER,
                file_type="text/plain",
                db_session=mock_db_session,
            )

            # Verify S3 client was called correctly
@@ -244,14 +237,13 @@ class TestExternalStorageFileStore:
            assert call_args[1]["Key"] == "onyx-files/public/test-file.txt"
            assert call_args[1]["ContentType"] == "text/plain"

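A hedged sketch of the tenant-aware key layout that the assertions above imply: "<prefix>/<tenant-or-public>/<file_id>", with "public" as the fallback when no tenant is set. The helper below is an illustration of that layout only, not the store's actual _get_s3_key implementation.

def build_s3_key(prefix: str, tenant_id: str | None, file_id: str) -> str:
    # Fall back to the "public" namespace when no tenant id is available
    return f"{prefix}/{tenant_id or 'public'}/{file_id}"


assert build_s3_key("onyx-files", None, "test-file.txt") == (
    "onyx-files/public/test-file.txt"
)
assert build_s3_key("custom-prefix", "test-tenant", "test-file.txt") == (
    "custom-prefix/test-tenant/test-file.txt"
)
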
    def test_minio_client_initialization(self, db_session: Session) -> None:
    def test_minio_client_initialization(self) -> None:
        """Test S3 client initialization with MinIO endpoint"""
        with (
            patch("boto3.client") as mock_boto3,
            patch("urllib3.disable_warnings"),
        ):
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="minioadmin",
                aws_secret_access_key="minioadmin",
@@ -277,11 +269,10 @@ class TestExternalStorageFileStore:
            assert config.signature_version == "s3v4"
            assert config.s3["addressing_style"] == "path"

    def test_minio_ssl_verification_enabled(self, db_session: Session) -> None:
    def test_minio_ssl_verification_enabled(self) -> None:
        """Test MinIO with SSL verification enabled"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="test-key",
                aws_secret_access_key="test-secret",
@@ -295,11 +286,10 @@ class TestExternalStorageFileStore:
            assert "verify" not in call_kwargs or call_kwargs.get("verify") is not False
            assert call_kwargs["endpoint_url"] == "https://minio.example.com"

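For reference, a minimal sketch of the boto3 client settings the MinIO assertions above check for: s3v4 signing and path-style addressing against a custom endpoint. It mirrors the test expectations, not the store's internal code; the endpoint URL and credentials are the illustrative values used in the tests.

import boto3
from botocore.config import Config

s3 = boto3.client(
    "s3",
    endpoint_url="https://minio.example.com",
    aws_access_key_id="minioadmin",       # illustrative credentials
    aws_secret_access_key="minioadmin",
    config=Config(signature_version="s3v4", s3={"addressing_style": "path"}),
)
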
    def test_aws_s3_without_endpoint_url(self, db_session: Session) -> None:
    def test_aws_s3_without_endpoint_url(self) -> None:
        """Test that regular AWS S3 doesn't include endpoint URL or custom config"""
        with patch("boto3.client") as mock_boto3:
            file_store = S3BackedFileStore(
                db_session,
                bucket_name="test-bucket",
                aws_access_key_id="test-key",
                aws_secret_access_key="test-secret",
@@ -321,8 +311,8 @@ class TestExternalStorageFileStore:
class TestFileStoreInterface:
    """Test the general file store interface"""

    def test_file_store_always_external_storage(self, db_session: Session) -> None:
    def test_file_store_always_external_storage(self) -> None:
        """Test that external storage file store is always returned"""
        # File store should always be S3BackedFileStore regardless of environment
        file_store = get_default_file_store(db_session)
        file_store = get_default_file_store()
        assert isinstance(file_store, S3BackedFileStore)

@@ -472,6 +472,7 @@ services:
    environment:
      MINIO_ROOT_USER: ${MINIO_ROOT_USER:-minioadmin}
      MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-minioadmin}
      # Note: we've seen the default bucket creation logic not work in some cases
      MINIO_DEFAULT_BUCKETS: ${S3_FILE_STORE_BUCKET_NAME:-onyx-file-store-bucket}
    volumes:
      - minio_data:/data

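Since the compose note above warns that MinIO's MINIO_DEFAULT_BUCKETS auto-creation can fail, here is a hedged sketch of creating the bucket explicitly with boto3. The endpoint, credentials, and bucket name are the compose defaults shown above; adjust for your deployment.

import boto3
from botocore.exceptions import ClientError

s3 = boto3.client(
    "s3",
    endpoint_url="http://localhost:9000",      # assumed local MinIO endpoint
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
)
bucket = "onyx-file-store-bucket"
try:
    s3.head_bucket(Bucket=bucket)              # succeeds if the bucket already exists
except ClientError:
    s3.create_bucket(Bucket=bucket)            # otherwise create it explicitly
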
@@ -0,0 +1,81 @@
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "onyx-stack.fullname" . }}-celery-worker-docfetching
  labels:
    {{- include "onyx-stack.labels" . | nindent 4 }}
spec:
  {{- if not .Values.celery_worker_docfetching.autoscaling.enabled }}
  replicas: {{ .Values.celery_worker_docfetching.replicaCount }}
  {{- end }}
  selector:
    matchLabels:
      {{- include "onyx-stack.selectorLabels" . | nindent 6 }}
      {{- if .Values.celery_worker_docfetching.deploymentLabels }}
      {{- toYaml .Values.celery_worker_docfetching.deploymentLabels | nindent 6 }}
      {{- end }}
  template:
    metadata:
      {{- with .Values.celery_worker_docfetching.podAnnotations }}
      annotations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      labels:
        {{- include "onyx-stack.labels" . | nindent 8 }}
        {{- with .Values.celery_worker_docfetching.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "onyx-stack.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.celery_worker_docfetching.podSecurityContext | nindent 8 }}
      containers:
        - name: celery-worker-docfetching
          securityContext:
            {{- toYaml .Values.celery_worker_docfetching.securityContext | nindent 12 }}
          image: "{{ .Values.celery_shared.image.repository }}:{{ .Values.celery_shared.image.tag | default .Chart.AppVersion }}"
          imagePullPolicy: {{ .Values.celery_shared.image.pullPolicy }}
          command:
            [
              "celery",
              "-A",
              "onyx.background.celery.versioned_apps.docfetching",
              "worker",
              "--loglevel=INFO",
              "--hostname=docfetching@%n",
              "-Q",
              "connector_doc_fetching,user_files_indexing",
            ]
          resources:
            {{- toYaml .Values.celery_worker_docfetching.resources | nindent 12 }}
          envFrom:
            - configMapRef:
                name: {{ .Values.config.envConfigMapName }}
          env:
            {{- include "onyx-stack.envSecrets" . | nindent 12}}
          startupProbe:
            {{ .Values.celery_shared.startupProbe | toYaml | nindent 12}}
          readinessProbe:
            {{ .Values.celery_shared.readinessProbe | toYaml | nindent 12}}
            exec:
              command:
                - /bin/bash
                - -c
                - >
                  python onyx/background/celery/celery_k8s_probe.py
                  --probe readiness
                  --filename /tmp/onyx_k8s_docfetching_readiness.txt
          livenessProbe:
            {{ .Values.celery_shared.livenessProbe | toYaml | nindent 12}}
            exec:
              command:
                - /bin/bash
                - -c
                - >
                  python onyx/background/celery/celery_k8s_probe.py
                  --probe liveness
                  --filename /tmp/onyx_k8s_docfetching_liveness.txt
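The probes above both shell out to celery_k8s_probe.py with a probe type and a per-worker heartbeat filename. A hedged sketch of that pattern, assuming the worker periodically touches the file and the probe fails when it is missing or stale; this illustrates the mechanism, not the actual onyx/background/celery/celery_k8s_probe.py implementation.

import argparse
import sys
import time
from pathlib import Path

MAX_AGE_SECONDS = 60  # assumed staleness threshold


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--probe", choices=["readiness", "liveness"], required=True)
    parser.add_argument("--filename", required=True)
    args = parser.parse_args()

    heartbeat = Path(args.filename)
    if not heartbeat.exists():
        return 1  # worker has not reported yet, probe fails
    age = time.time() - heartbeat.stat().st_mtime
    return 0 if age <= MAX_AGE_SECONDS else 1


if __name__ == "__main__":
    sys.exit(main())
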
@@ -48,7 +48,7 @@ spec:
              "--loglevel=INFO",
              "--hostname=indexing@%n",
              "-Q",
              "connector_indexing",
              "docprocessing",
            ]
          resources:
            {{- toYaml .Values.celery_worker_indexing.resources | nindent 12 }}

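The hunk above repoints the indexing worker from the connector_indexing queue to docprocessing. A minimal sketch of how a task might be routed onto that queue with Celery's task_routes; the app name, broker URL, and task name are illustrative, not Onyx's actual task registry.

from celery import Celery

app = Celery("onyx_example", broker="redis://localhost:6379/0")
# Route the example task to the queue the indexing worker now consumes
app.conf.task_routes = {"example.process_document": {"queue": "docprocessing"}}


@app.task(name="example.process_document")
def process_document(document_id: str) -> None:
    ...
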
@@ -538,6 +538,27 @@ slackbot:
    limits:
      cpu: "1000m"
      memory: "2000Mi"
celery_worker_docfetching:
  replicaCount: 1
  autoscaling:
    enabled: false
  podAnnotations: {}
  podLabels:
    scope: onyx-backend-celery
    app: celery-worker-docfetching
  deploymentLabels:
    app: celery-worker-docfetching
  podSecurityContext:
    {}
  securityContext:
    privileged: true
    runAsUser: 0
  resources: {}
  volumes: [] # Additional volumes on the output Deployment definition.
  volumeMounts: [] # Additional volumeMounts on the output Deployment definition.
  nodeSelector: {}
  tolerations: []
  affinity: {}

redis:
  enabled: true

@@ -217,7 +217,7 @@ const collections = (
          <div className="ml-1">Document Processing</div>
        </div>
      ),
      link: "/admin/configuration/document-processing",
      link: "/admin/configuration/docfetching",
    },
    ...(kgExposed
      ? [
