Compare commits


5 Commits

Author SHA1 Message Date
Richard Kuo (Danswer)
8a915d3d3e Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/propagate-exceptions 2025-02-13 14:10:38 -08:00
Richard Kuo (Danswer)
2e1b13f400 refactor the watchdog more to emit data consistently at the end of the function 2025-02-13 14:08:14 -08:00
Richard Kuo (Danswer)
18d4354f58 Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/propagate-exceptions
# Conflicts:
#	backend/onyx/background/celery/tasks/indexing/tasks.py
2025-02-12 19:38:41 -08:00
Richard Kuo (Danswer)
7641f225c2 remove debug testing 2025-02-12 13:50:26 -08:00
Richard Kuo (Danswer)
454c20e4d8 better propagation of exceptions up the stack 2025-02-12 13:49:59 -08:00
5 changed files with 319 additions and 208 deletions

View File

@@ -1,9 +1,10 @@
import multiprocessing
import os
import sys
import time
import traceback
from datetime import datetime
from datetime import timezone
from enum import Enum
from http import HTTPStatus
from time import sleep
from typing import Any
@@ -15,6 +16,7 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -26,6 +28,7 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attem
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
@@ -70,6 +73,65 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()


class IndexingWatchdogTerminalStatus(str, Enum):
    """The different statuses the watchdog can finish with"""

    UNDEFINED = "undefined"

    SUCCEEDED = "succeeded"

    SPAWN_FAILED = "spawn_failed"  # connector spawn failed
    CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

    # the watchdog received a termination signal
    TERMINATED_BY_SIGNAL = "terminated_by_signal"

    # the watchdog terminated the task due to no activity
    TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"


class SimpleJobResult:
    """The data we want to have when the watchdog finishes"""

    def __init__(self) -> None:
        self.status = IndexingWatchdogTerminalStatus.UNDEFINED
        self.connector_source = None
        self.exit_code = None
        self.exception_str = None

    status: IndexingWatchdogTerminalStatus
    connector_source: str | None
    exit_code: int | None
    exception_str: str | None


class ConnectorIndexingContext(BaseModel):
    tenant_id: str | None
    cc_pair_id: int
    search_settings_id: int
    index_attempt_id: int


class ConnectorIndexingLogBuilder:
    def __init__(self, ctx: ConnectorIndexingContext):
        self.ctx = ctx

    def build(self, msg: str, **kwargs: Any) -> str:
        msg_final = (
            f"{msg}: "
            f"tenant_id={self.ctx.tenant_id} "
            f"attempt={self.ctx.index_attempt_id} "
            f"cc_pair={self.ctx.cc_pair_id} "
            f"search_settings={self.ctx.search_settings_id}"
        )

        # Append extra keyword arguments in logfmt style
        if kwargs:
            extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items())
            msg_final = f"{msg_final} {extra_logfmt}"

        return msg_final
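
For reference, a minimal sketch of what the builder above emits (the IDs below are made up for illustration):

    # Hypothetical usage of ConnectorIndexingLogBuilder as defined in this diff.
    ctx = ConnectorIndexingContext(
        tenant_id="t1", cc_pair_id=22, search_settings_id=3, index_attempt_id=104
    )
    log_builder = ConnectorIndexingLogBuilder(ctx)

    msg = log_builder.build("Indexing watchdog - finished", status="succeeded")
    # -> "Indexing watchdog - finished: tenant_id=t1 attempt=104 cc_pair=22 search_settings=3 status=succeeded"
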
def monitor_ccpair_indexing_taskset(
    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -496,7 +558,6 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
# 20 is the documented default for httpx max_keepalive_connections
@@ -595,7 +656,7 @@ def connector_indexing_task(
                raise ValueError(
                    f"Index attempt not found: index_attempt={index_attempt_id}"
                )

            attempt_found = True
            # attempt_found = True

            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
@@ -650,20 +711,6 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt_found:
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:
@@ -678,41 +725,46 @@ def connector_indexing_task(
    return n_final_progress


def connector_indexing_task_wrapper(
    index_attempt_id: int,
    cc_pair_id: int,
    search_settings_id: int,
    tenant_id: str | None,
    is_ee: bool,
) -> int | None:
    """Just wraps connector_indexing_task so we can log any exceptions before
    re-raising it."""
    result: int | None = None


def process_job_result(
    job: SimpleJob,
    connector_source: str | None,
    redis_connector_index: RedisConnectorIndex,
    log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
    result = SimpleJobResult()
    result.connector_source = connector_source

    try:
        result = connector_indexing_task(
            index_attempt_id,
            cc_pair_id,
            search_settings_id,
            tenant_id,
            is_ee,
        )
    except Exception:
        logger.exception(
            f"connector_indexing_task exceptioned: "
            f"tenant={tenant_id} "
            f"index_attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

    if job.process:
        result.exit_code = job.process.exitcode

        # There is a cloud related bug outside of our code
        # where spawned tasks return with an exit code of 1.
        # Unfortunately, exceptions also return with an exit code of 1,
        # so just raising an exception isn't informative
        # Exiting with 255 makes it possible to distinguish between normal exits
        # and exceptions.
        sys.exit(255)

    if job.status != "error":
        result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
        return result

    ignore_exitcode = False

    # In EKS, there is an edge case where successful tasks return exit
    # code 1 in the cloud due to the set_spawn_method not sticking.
    # We've since worked around this, but the following is a safe way to
    # work around this issue. Basically, we ignore the job error state
    # if the completion signal is OK.
    status_int = redis_connector_index.get_completion()
    if status_int:
        status_enum = HTTPStatus(status_int)
        if status_enum == HTTPStatus.OK:
            ignore_exitcode = True

    if not ignore_exitcode:
        result.status = IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED
        result.exception_str = job.exception()
    else:
        task_logger.warning(
            log_builder.build(
                "Indexing watchdog - spawned task has non-zero exit code "
                "but completion signal is OK. Continuing...",
                exit_code=str(result.exit_code),
            )
        )

    return result
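
The exit-code convention referenced above can be demonstrated in isolation. A minimal, self-contained sketch (names are illustrative, not from this PR):

    import multiprocessing as mp
    import sys

    def child() -> None:
        try:
            raise RuntimeError("boom")
        except Exception:
            # handled exception -> distinctive exit code, per the comments above
            sys.exit(255)

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        p = ctx.Process(target=child)
        p.start()
        p.join()
        # exitcode 255 => our code caught an exception; a bare exit code of 1
        # is ambiguous because the cloud/EKS issue described above also produces it
        print(p.exitcode)
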
@@ -730,12 +782,32 @@ def connector_indexing_proxy_task(
    search_settings_id: int,
    tenant_id: str | None,
) -> None:
    """celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
    """celery out of process task execution strategy is pool=prefork, but it uses fork,
    and forking is inherently unstable.

    To work around this, we use pool=threads and proxy our work to a spawned task.

    TODO(rkuo): refactor this so that there is a single return path where we canonically
    log the result of running this function.
    """
    start = time.monotonic()

    result = SimpleJobResult()

    ctx = ConnectorIndexingContext(
        tenant_id=tenant_id,
        cc_pair_id=cc_pair_id,
        search_settings_id=search_settings_id,
        index_attempt_id=index_attempt_id,
    )

    log_builder = ConnectorIndexingLogBuilder(ctx)

    task_logger.info(
        f"Indexing watchdog - starting: attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"mp_start_method={multiprocessing.get_start_method()}"
        log_builder.build(
            "Indexing watchdog - starting",
            mp_start_method=str(multiprocessing.get_start_method()),
        )
    )
    if not self.request.id:
@@ -744,7 +816,7 @@ def connector_indexing_proxy_task(
    client = SimpleJobClient()

    job = client.submit(
        connector_indexing_task_wrapper,
        connector_indexing_task,
        index_attempt_id,
        cc_pair_id,
        search_settings_id,
@@ -754,89 +826,97 @@ def connector_indexing_proxy_task(
    )

    if not job:
        result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
        task_logger.info(
            f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
            log_builder.build(
                "Indexing watchdog - finished",
                status=str(result.status.value),
                exit_code=str(result.exit_code),
            )
        )
        return

    task_logger.info(
        f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
    )
    task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)
    while True:
        sleep(5)

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            index_attempt = get_index_attempt(
                db_session=db_session, index_attempt_id=index_attempt_id
            )
            if not index_attempt:
                raise RuntimeError("Index attempt not found")

        # renew watchdog signal (this has a shorter timeout than set_active)
        redis_connector_index.set_watchdog(True)

        # renew active signal
        redis_connector_index.set_active()

        # if the job is done, clean up and break
        if job.done():
            exit_code: int | None
            try:
                if job.status == "error":
                    ignore_exitcode = False

                    exit_code = None
                    if job.process:
                        exit_code = job.process.exitcode

                    # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
                    # even though logging clearly indicates successful completion
                    # to work around this, we ignore the job error state if the completion signal is OK
                    status_int = redis_connector_index.get_completion()
                    if status_int:
                        status_enum = HTTPStatus(status_int)
                        if status_enum == HTTPStatus.OK:
                            ignore_exitcode = True

                    if not ignore_exitcode:
                        raise RuntimeError("Spawned task exceptioned.")

                    task_logger.warning(
                        "Indexing watchdog - spawned task has non-zero exit code "
                        "but completion signal is OK. Continuing...: "
                        f"attempt={index_attempt_id} "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"search_settings={search_settings_id} "
                        f"exit_code={exit_code}"
                    )
            except Exception:
                task_logger.error(
                    "Indexing watchdog - spawned task exceptioned: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code} "
                    f"error={job.exception()}"
                )
                raise
            finally:
                job.release()
            break

        # if a termination signal is detected, clean up and break
        if self.request.id and redis_connector_index.terminating(self.request.id):
            task_logger.warning(
                "Indexing watchdog - termination signal detected: "
                f"attempt={index_attempt_id} "
                f"cc_pair={cc_pair_id} "
                f"search_settings={search_settings_id}"

            result.connector_source = (
                index_attempt.connector_credential_pair.connector.source.value
            )

        while True:
            sleep(5)

            # renew watchdog signal (this has a shorter timeout than set_active)
            redis_connector_index.set_watchdog(True)

            # renew active signal
            redis_connector_index.set_active()

            # if the job is done, clean up and break
            if job.done():
                try:
                    result = process_job_result(
                        job, result.connector_source, redis_connector_index, log_builder
                    )
                except Exception:
                    task_logger.exception(
                        log_builder.build(
                            "Indexing watchdog - spawned task exceptioned"
                        )
                    )
                finally:
                    job.release()
                    break

            # if a termination signal is detected, clean up and break
            if self.request.id and redis_connector_index.terminating(self.request.id):
                task_logger.warning(
                    log_builder.build("Indexing watchdog - termination signal detected")
                )

                result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
                break

            # if the spawned task is still running, restart the check once again
            # if the index attempt is not in a finished status
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    index_attempt = get_index_attempt(
                        db_session=db_session, index_attempt_id=index_attempt_id
                    )

                    if not index_attempt:
                        continue

                    if not index_attempt.is_finished():
                        continue
            except Exception:
                # if the DB exceptioned, just restart the check.
                # polling the index attempt status doesn't need to be strongly consistent
                task_logger.exception(
                    log_builder.build(
                        "Indexing watchdog - transient exception looking up index attempt"
                    )
                )
                continue
    except Exception:
        result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
        result.exception_str = traceback.format_exc()
    elapsed = time.monotonic() - start

    if result.exception_str is None:
        # print without exception
        if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    mark_attempt_canceled(
@@ -847,46 +927,53 @@ def connector_indexing_proxy_task(
            except Exception:
                # if the DB exceptions, we'll just get an unfriendly failure message
                # in the UI instead of the cancellation message
                logger.exception(
                    "Indexing watchdog - transient exception marking index attempt as canceled: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id}"
                task_logger.exception(
                    log_builder.build(
                        "Indexing watchdog - transient exception marking index attempt as canceled"
                    )
                )

            job.cancel()
            break

        # if the spawned task is still running, restart the check once again
        # if the index attempt is not in a finished status
        task_logger.info(
            log_builder.build(
                "Indexing watchdog - finished",
                source=result.connector_source,
                status=str(result.status.value),
                exit_code=str(result.exit_code),
                elapsed=f"{elapsed:.2f}s",
            )
        )
    else:
        # print with exception
        try:
            with get_session_with_tenant(tenant_id) as db_session:
                index_attempt = get_index_attempt(
                    db_session=db_session, index_attempt_id=index_attempt_id

                failure_reason = (
                    f"Spawned task exceptioned: exit_code={result.exit_code}"
                )
                mark_attempt_failed(
                    ctx.index_attempt_id,
                    db_session,
                    failure_reason=failure_reason,
                    full_exception_trace=result.exception_str,
                )

                if not index_attempt:
                    continue

                if not index_attempt.is_finished():
                    continue
        except Exception:
            # if the DB exceptioned, just restart the check.
            # polling the index attempt status doesn't need to be strongly consistent
            logger.exception(
                "Indexing watchdog - transient exception looking up index attempt: "
                f"attempt={index_attempt_id} "
                f"tenant={tenant_id} "
                f"cc_pair={cc_pair_id} "
                f"search_settings={search_settings_id}"
            task_logger.exception(
                log_builder.build(
                    "Indexing watchdog - transient exception marking index attempt as failed"
                )
            )
            continue
        task_logger.warning(
            log_builder.build(
                "Indexing watchdog - finished",
                source=result.connector_source,
                status=result.status.value,
                exit_code=str(result.exit_code),
                exception=str(result.exception_str),
                elapsed=f"{elapsed:.2f}s",
            )
        )

    redis_connector_index.set_watchdog(False)

    task_logger.info(
        f"Indexing watchdog - finished: attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
    )

    return
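
The watchdog loop above leans on TTL-based signals in Redis ("renew watchdog signal", "renew active signal"). A minimal sketch of that heartbeat pattern, assuming plain redis-py and made-up key names rather than the PR's RedisConnectorIndex API:

    import redis

    WATCHDOG_TTL = 30  # assumed value; the point is that it is shorter than the active signal's TTL

    r = redis.Redis()

    def set_watchdog(attempt_id: int, alive: bool) -> None:
        key = f"indexing:watchdog:{attempt_id}"
        if alive:
            r.set(key, 1, ex=WATCHDOG_TTL)  # refresh the heartbeat with a TTL
        else:
            r.delete(key)

    def watchdog_alive(attempt_id: int) -> bool:
        # once the watchdog stops refreshing, the key expires and this turns False
        return r.exists(f"indexing:watchdog:{attempt_id}") == 1
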

View File

@@ -240,7 +240,8 @@ def validate_indexing_fence(
        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        logger.info(
            f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
            f"validate_indexing_fence - "
            f"Resetting fence in basic state without any activity: fence={fence_key}"
        )
        redis_connector_index.reset()
        return

View File

@@ -5,6 +5,8 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
import sys
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
@@ -28,7 +30,10 @@ JobStatusType = (
def _initializer(
    func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
    func: Callable,
    queue: mp.Queue,
    args: list | tuple,
    kwargs: dict[str, Any] | None = None,
) -> Any:
    """Initialize the child process with a fresh SQLAlchemy Engine.
@@ -52,13 +57,23 @@ def _initializer(
    )

    # Proceed with executing the target function
    return func(*args, **kwargs)
    try:
        return func(*args, **kwargs)
    except Exception:
        logger.exception("SimpleJob exceptioned")

        error_msg = traceback.format_exc()
        queue.put(error_msg)  # Send the exception to the parent process

        sys.exit(255)  # use a specific exit code when we handle a top level exception
def _run_in_process(
    func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
    func: Callable,
    queue: mp.Queue,
    args: list | tuple,
    kwargs: dict[str, Any] | None = None,
) -> None:
    _initializer(func, args, kwargs)
    _initializer(func, queue, args, kwargs)
@dataclass
@@ -67,6 +82,8 @@ class SimpleJob:
    id: int
    process: Optional["SpawnProcess"] = None
    queue: Optional[mp.Queue] = None
    _exception: Optional[str] = None

    def cancel(self) -> bool:
        return self.release()
@@ -100,9 +117,15 @@ class SimpleJob:
    def exception(self) -> str:
        """Needed to match the Dask API, but not implemented since we don't currently
        have a way to get back the exception information from the child process."""
        return (
            f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
        )
        """Retrieve exception from the multiprocessing queue if available."""
        if self._exception is None and self.queue and not self.queue.empty():
            self._exception = self.queue.get()  # Get exception from queue

        if self._exception:
            return self._exception

        return f"Job with ID '{self.id}' did not report an exception."
class SimpleJobClient:
@@ -137,8 +160,11 @@ class SimpleJobClient:
        # this approach allows us to always "spawn" a new process regardless of
        # get_start_method's current setting
        ctx = mp.get_context("spawn")
        process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
        job = SimpleJob(id=job_id, process=process)
        queue = ctx.Queue()
        process = ctx.Process(
            target=_run_in_process, args=(func, queue, args), daemon=True
        )
        job = SimpleJob(id=job_id, process=process, queue=queue)
        process.start()

        self.jobs[job_id] = job
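
Putting the pieces together, a hypothetical usage sketch of the new exception plumbing (flaky_connector is invented for illustration; submit/done/status/exception/release follow the diff above):

    from time import sleep

    def flaky_connector() -> None:
        # must be defined at module level so the spawned process can import it
        raise RuntimeError("connector blew up")

    client = SimpleJobClient()
    job = client.submit(flaky_connector)
    assert job is not None

    while not job.done():
        sleep(1)

    if job.status == "error":
        # the child's traceback was pushed onto the mp.Queue by _initializer,
        # and the child exited with code 255
        print(job.exception())
    job.release()
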

View File

@@ -558,46 +558,43 @@ def run_indexing_entrypoint(
    is_ee: bool = False,
    callback: IndexingHeartbeatInterface | None = None,
) -> None:
    try:
        if is_ee:
            global_version.set_ee()
    """Don't swallow exceptions here ... propagate them up."""

        # set the indexing attempt ID so that all log messages from this process
        # will have it added as a prefix
        TaskAttemptSingleton.set_cc_and_index_id(
            index_attempt_id, connector_credential_pair_id
    if is_ee:
        global_version.set_ee()

    # set the indexing attempt ID so that all log messages from this process
    # will have it added as a prefix
    TaskAttemptSingleton.set_cc_and_index_id(
        index_attempt_id, connector_credential_pair_id
    )

    with get_session_with_tenant(tenant_id) as db_session:
        # TODO: remove long running session entirely
        attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)

        tenant_str = ""
        if tenant_id is not None:
            tenant_str = f" for tenant {tenant_id}"

        connector_name = attempt.connector_credential_pair.connector.name
        connector_config = (
            attempt.connector_credential_pair.connector.connector_specific_config
        )
        with get_session_with_tenant(tenant_id) as db_session:
            # TODO: remove long running session entirely
            attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
        credential_id = attempt.connector_credential_pair.credential_id

            tenant_str = ""
            if tenant_id is not None:
                tenant_str = f" for tenant {tenant_id}"

    logger.info(
        f"Indexing starting{tenant_str}: "
        f"connector='{connector_name}' "
        f"config='{connector_config}' "
        f"credentials='{credential_id}'"
    )

            connector_name = attempt.connector_credential_pair.connector.name
            connector_config = (
                attempt.connector_credential_pair.connector.connector_specific_config
            )
            credential_id = attempt.connector_credential_pair.credential_id

    with get_session_with_tenant(tenant_id) as db_session:
        _run_indexing(db_session, index_attempt_id, tenant_id, callback)

        logger.info(
            f"Indexing starting{tenant_str}: "
            f"connector='{connector_name}' "
            f"config='{connector_config}' "
            f"credentials='{credential_id}'"
        )

        with get_session_with_tenant(tenant_id) as db_session:
            _run_indexing(db_session, index_attempt_id, tenant_id, callback)

        logger.info(
            f"Indexing finished{tenant_str}: "
            f"connector='{connector_name}' "
            f"config='{connector_config}' "
            f"credentials='{credential_id}'"
        )
    except Exception as e:
        logger.exception(
            f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
        )

    logger.info(
        f"Indexing finished{tenant_str}: "
        f"connector='{connector_name}' "
        f"config='{connector_config}' "
        f"credentials='{credential_id}'"
    )
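
In short, the old entrypoint wrapped everything in try/except and logged, so a failed spawned process still exited 0; the new one lets exceptions propagate so the watchdog can see them. A small illustrative sketch (function and logger names invented):

    import logging

    logger = logging.getLogger(__name__)

    def do_indexing() -> None:
        raise RuntimeError("indexing failed")

    def old_entrypoint() -> None:
        # pre-PR behavior: exception swallowed here, process exits 0
        try:
            do_indexing()
        except Exception as e:
            logger.exception(f"indexing job failed due to {e}")

    def new_entrypoint() -> None:
        # post-PR behavior: exception propagates; in the PR it is caught by
        # _initializer in job_client.py, pushed onto the queue, and the child
        # exits with code 255
        do_indexing()
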

View File

@@ -76,6 +76,6 @@ class ConnectorRunner:
            )

            logger.error(
                f"Error in connector. type: {exc_type};\n"
                f"local_vars below -> \n{local_vars_str}"
                f"local_vars below -> \n{local_vars_str[:1024]}"
            )
            raise