Compare commits

..

1 Commits

Author SHA1 Message Date
Justin Tahara
e19198f1f2 chore(mt): reduce cleanup-idle-sandboxes beat cadence (#9984) 2026-04-08 02:29:21 +00:00
8 changed files with 52 additions and 423 deletions

View File

@@ -1,4 +1,3 @@
import time
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
@@ -31,8 +30,6 @@ from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.server.metrics.pruning_metrics import inc_pruning_rate_limit_error
from onyx.server.metrics.pruning_metrics import observe_pruning_enumeration_duration
from onyx.utils.logger import setup_logger
@@ -133,7 +130,6 @@ def _extract_from_batch(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
connector_type: str = "unknown",
) -> SlimConnectorExtractionResult:
"""
Extract document IDs and hierarchy nodes from a runnable connector.
@@ -183,38 +179,21 @@ def extract_ids_from_runnable_connector(
)
# process raw batches to extract both IDs and hierarchy nodes
enumeration_start = time.monotonic()
try:
for doc_list in raw_batch_generator:
if callback and callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
for doc_list in raw_batch_generator:
if callback and callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
all_raw_id_to_parent.update(batch_ids)
all_hierarchy_nodes.extend(batch_nodes)
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
all_raw_id_to_parent.update(batch_ids)
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
except Exception as e:
# Best-effort rate limit detection via string matching.
# Connectors surface rate limits inconsistently — some raise HTTP 429,
# some use SDK-specific exceptions (e.g. google.api_core.exceptions.ResourceExhausted)
# that may or may not include "rate limit" or "429" in the message.
# TODO(Bo): replace with a standard ConnectorRateLimitError exception that all
# connectors raise when rate limited, making this check precise.
error_str = str(e)
if "rate limit" in error_str.lower() or "429" in error_str:
inc_pruning_rate_limit_error(connector_type)
raise
finally:
observe_pruning_enumeration_duration(
time.monotonic() - enumeration_start, connector_type
)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
raw_id_to_parent=all_raw_id_to_parent,

View File

@@ -142,7 +142,14 @@ beat_task_templates: list[dict] = [
{
"name": "cleanup-idle-sandboxes",
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
"schedule": timedelta(minutes=1),
# SANDBOX_IDLE_TIMEOUT_SECONDS defaults to 1 hour, so there is no
# functional reason to scan more often than every ~15 minutes. In the
# cloud this is multiplied by CLOUD_BEAT_MULTIPLIER_DEFAULT (=8) so
# the effective cadence becomes ~2 hours, which still meets the
# idle-detection SLA. The previous 1-minute base schedule produced
# an 8-minute per-tenant fan-out and was the dominant source of
# background DB load on the cloud cluster.
"schedule": timedelta(minutes=15),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,

View File

@@ -72,7 +72,6 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
@@ -571,9 +570,8 @@ def connector_pruning_generator_task(
)
# Extract docs and hierarchy nodes from the source
connector_type = cc_pair.connector.source.value
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback, connector_type=connector_type
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.raw_id_to_parent
@@ -638,46 +636,40 @@ def connector_pruning_generator_task(
commit=True,
)
diff_start = time.monotonic()
try:
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
task_logger.info(
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.prune.generate_tasks(
set(doc_ids_to_remove), self.app, db_session, None
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
finally:
observe_pruning_diff_duration(
time.monotonic() - diff_start, connector_type
)
task_logger.info(
f"RedisConnector.prune.generate_tasks finished. cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
redis_connector.prune.generator_complete = tasks_generated

View File

@@ -1,72 +0,0 @@
"""Pruning-specific Prometheus metrics.
Tracks three pruning pipeline phases for connector_pruning_generator_task:
1. Document ID enumeration duration (extract_ids_from_runnable_connector)
2. Diff + dispatch duration (DB lookup, set diff, generate_tasks)
3. Rate limit errors during enumeration
All metrics are labeled by connector_type to identify which connector sources
are the most expensive to prune. cc_pair_id is intentionally excluded to avoid
unbounded cardinality.
Usage:
from onyx.server.metrics.pruning_metrics import (
observe_pruning_enumeration_duration,
observe_pruning_diff_duration,
inc_pruning_rate_limit_error,
)
"""
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()

# Histogram for phase 1 of pruning (see module docstring): how long
# extract_ids_from_runnable_connector spends enumerating document IDs from the
# source. Buckets span 1 second to 1 hour.
PRUNING_ENUMERATION_DURATION = Histogram(
    "onyx_pruning_enumeration_duration_seconds",
    "Duration of document ID enumeration from the source connector during pruning",
    ["connector_type"],
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)

# Histogram for phase 2 of pruning: DB lookup, set diff, and subtask dispatch.
# Same 1s-1h bucket layout as the enumeration histogram for easy comparison.
PRUNING_DIFF_DURATION = Histogram(
    "onyx_pruning_diff_duration_seconds",
    "Duration of diff computation and subtask dispatch during pruning",
    ["connector_type"],
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)

# Counter for phase 3: enumeration failures classified as rate limits (by
# best-effort message matching on "rate limit"/"429" at the call site).
PRUNING_RATE_LIMIT_ERRORS = Counter(
    "onyx_pruning_rate_limit_errors_total",
    "Total rate limit errors encountered during pruning document ID enumeration",
    ["connector_type"],
)
def observe_pruning_enumeration_duration(
    duration_seconds: float, connector_type: str
) -> None:
    """Record one enumeration-phase duration sample for *connector_type*.

    Metric emission is best-effort: any failure is logged at debug level and
    never propagated to the pruning task.
    """
    try:
        labeled = PRUNING_ENUMERATION_DURATION.labels(connector_type=connector_type)
        labeled.observe(duration_seconds)
    except Exception:
        logger.debug("Failed to record pruning enumeration duration", exc_info=True)
def observe_pruning_diff_duration(duration_seconds: float, connector_type: str) -> None:
    """Record one diff/dispatch-phase duration sample for *connector_type*.

    Metric emission is best-effort: any failure is logged at debug level and
    never propagated to the pruning task.
    """
    try:
        labeled = PRUNING_DIFF_DURATION.labels(connector_type=connector_type)
        labeled.observe(duration_seconds)
    except Exception:
        logger.debug("Failed to record pruning diff duration", exc_info=True)
def inc_pruning_rate_limit_error(connector_type: str) -> None:
    """Count one rate-limit error seen while enumerating *connector_type* docs.

    Metric emission is best-effort: any failure is logged at debug level and
    never propagated to the pruning task.
    """
    try:
        counter = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type=connector_type)
        counter.inc()
    except Exception:
        logger.debug("Failed to record pruning rate limit error", exc_info=True)

View File

@@ -1,149 +0,0 @@
"""Unit tests for extract_ids_from_runnable_connector metrics instrumentation."""
from collections.abc import Iterator
from unittest.mock import MagicMock
import pytest
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import SlimDocument
from onyx.server.metrics.pruning_metrics import PRUNING_ENUMERATION_DURATION
from onyx.server.metrics.pruning_metrics import PRUNING_RATE_LIMIT_ERRORS
def _make_slim_connector(doc_ids: list[str]) -> SlimConnector:
    """Build a mock SlimConnector that yields all *doc_ids* in a single batch."""
    mock_connector = MagicMock(spec=SlimConnector)

    def _slim_doc(doc_id: str) -> MagicMock:
        # Each slim doc carries an id and no hierarchy parent.
        return MagicMock(spec=SlimDocument, id=doc_id, parent_hierarchy_raw_node_id=None)

    batch = [_slim_doc(doc_id) for doc_id in doc_ids]
    mock_connector.retrieve_all_slim_docs.return_value = iter([batch])
    return mock_connector
def _raising_connector(message: str) -> SlimConnector:
    """Build a mock SlimConnector whose slim-doc generator raises *message*."""
    mock_connector = MagicMock(spec=SlimConnector)

    def _boom() -> Iterator:
        # Unreachable `yield` makes this a generator; the raise fires on the
        # first iteration attempt rather than when the mock method is called.
        raise Exception(message)
        yield

    mock_connector.retrieve_all_slim_docs.return_value = _boom()
    return mock_connector
class TestEnumerationDuration:
    """Enumeration-duration histogram is observed on both success and failure."""

    @staticmethod
    def _duration_sum(connector_type: str) -> float:
        # Read the current cumulative sum of the labeled histogram.
        return PRUNING_ENUMERATION_DURATION.labels(
            connector_type=connector_type
        )._sum.get()

    def test_recorded_on_success(self) -> None:
        connector = _make_slim_connector(["doc1"])
        before = self._duration_sum("google_drive")
        extract_ids_from_runnable_connector(connector, connector_type="google_drive")
        # duration observed (non-negative), so the sum never decreases
        assert self._duration_sum("google_drive") >= before

    def test_recorded_on_exception(self) -> None:
        connector = _raising_connector("unexpected error")
        before = self._duration_sum("confluence")
        with pytest.raises(Exception):
            extract_ids_from_runnable_connector(connector, connector_type="confluence")
        # duration is observed even when enumeration blows up
        assert self._duration_sum("confluence") >= before
class TestRateLimitDetection:
    """Best-effort rate limit detection via exception-message matching.

    extract_ids_from_runnable_connector classifies an enumeration failure as a
    rate limit when the exception text contains "rate limit" (case-insensitive)
    or "429", increments PRUNING_RATE_LIMIT_ERRORS under the passed
    connector_type label, and re-raises the original exception.
    """

    def test_increments_on_rate_limit_message(self) -> None:
        connector = _raising_connector("rate limit exceeded")
        before = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        with pytest.raises(Exception, match="rate limit exceeded"):
            extract_ids_from_runnable_connector(
                connector, connector_type="google_drive"
            )
        after = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        assert after == before + 1

    def test_increments_on_429_in_message(self) -> None:
        # "429" alone (without the literal words "rate limit") must also count.
        connector = _raising_connector("HTTP 429 Too Many Requests")
        before = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="confluence"
        )._value.get()
        with pytest.raises(Exception, match="429"):
            extract_ids_from_runnable_connector(connector, connector_type="confluence")
        after = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="confluence"
        )._value.get()
        assert after == before + 1

    def test_does_not_increment_on_non_rate_limit_exception(self) -> None:
        # Unrelated failures must not pollute the rate-limit counter.
        connector = _raising_connector("connection timeout")
        before = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="slack")._value.get()
        with pytest.raises(Exception, match="connection timeout"):
            extract_ids_from_runnable_connector(connector, connector_type="slack")
        after = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="slack")._value.get()
        assert after == before

    def test_rate_limit_detection_is_case_insensitive(self) -> None:
        # Detection lowercases the message, so "RATE LIMIT" still matches.
        connector = _raising_connector("RATE LIMIT exceeded")
        before = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="jira")._value.get()
        with pytest.raises(Exception):
            extract_ids_from_runnable_connector(connector, connector_type="jira")
        after = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="jira")._value.get()
        assert after == before + 1

    def test_connector_type_label_matches_input(self) -> None:
        # Only the label matching the passed connector_type may be incremented.
        connector = _raising_connector("rate limit exceeded")
        before_gd = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        before_jira = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="jira"
        )._value.get()
        with pytest.raises(Exception):
            extract_ids_from_runnable_connector(
                connector, connector_type="google_drive"
            )
        assert (
            PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="google_drive")._value.get()
            == before_gd + 1
        )
        assert (
            PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="jira")._value.get()
            == before_jira
        )

    def test_defaults_to_unknown_connector_type(self) -> None:
        # Omitting connector_type falls back to the "unknown" label.
        connector = _raising_connector("rate limit exceeded")
        before = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="unknown")._value.get()
        with pytest.raises(Exception):
            extract_ids_from_runnable_connector(connector)
        after = PRUNING_RATE_LIMIT_ERRORS.labels(connector_type="unknown")._value.get()
        assert after == before + 1

View File

@@ -1,128 +0,0 @@
"""Tests for pruning-specific Prometheus metrics."""
import pytest
from onyx.server.metrics.pruning_metrics import inc_pruning_rate_limit_error
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
from onyx.server.metrics.pruning_metrics import observe_pruning_enumeration_duration
from onyx.server.metrics.pruning_metrics import PRUNING_DIFF_DURATION
from onyx.server.metrics.pruning_metrics import PRUNING_ENUMERATION_DURATION
from onyx.server.metrics.pruning_metrics import PRUNING_RATE_LIMIT_ERRORS
class TestObservePruningEnumerationDuration:
    """observe_pruning_enumeration_duration records into the matching label only."""

    @staticmethod
    def _sum_for(connector_type: str) -> float:
        # Current cumulative sum of the labeled enumeration histogram.
        return PRUNING_ENUMERATION_DURATION.labels(
            connector_type=connector_type
        )._sum.get()

    def test_observes_duration(self) -> None:
        before = self._sum_for("google_drive")
        observe_pruning_enumeration_duration(10.0, "google_drive")
        assert self._sum_for("google_drive") == pytest.approx(before + 10.0)

    def test_labels_by_connector_type(self) -> None:
        before_gd = self._sum_for("google_drive")
        before_conf = self._sum_for("confluence")
        observe_pruning_enumeration_duration(5.0, "google_drive")
        # Only the targeted label moves; the other stays flat.
        assert self._sum_for("google_drive") == pytest.approx(before_gd + 5.0)
        assert self._sum_for("confluence") == pytest.approx(before_conf)

    def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
        def _explode(**_kwargs: object) -> None:
            # Replacement labels() that always fails.
            raise RuntimeError("boom")

        monkeypatch.setattr(PRUNING_ENUMERATION_DURATION, "labels", _explode)
        # The helper must swallow the metric error rather than propagate it.
        observe_pruning_enumeration_duration(1.0, "google_drive")
class TestObservePruningDiffDuration:
    """observe_pruning_diff_duration records into the matching label only."""

    @staticmethod
    def _sum_for(connector_type: str) -> float:
        # Current cumulative sum of the labeled diff-duration histogram.
        return PRUNING_DIFF_DURATION.labels(connector_type=connector_type)._sum.get()

    def test_observes_duration(self) -> None:
        before = self._sum_for("confluence")
        observe_pruning_diff_duration(3.0, "confluence")
        assert self._sum_for("confluence") == pytest.approx(before + 3.0)

    def test_labels_by_connector_type(self) -> None:
        before_conf = self._sum_for("confluence")
        before_slack = self._sum_for("slack")
        observe_pruning_diff_duration(2.0, "confluence")
        # Only the targeted label moves; the other stays flat.
        assert self._sum_for("confluence") == pytest.approx(before_conf + 2.0)
        assert self._sum_for("slack") == pytest.approx(before_slack)

    def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
        def _explode(**_kwargs: object) -> None:
            # Replacement labels() that always fails.
            raise RuntimeError("boom")

        monkeypatch.setattr(PRUNING_DIFF_DURATION, "labels", _explode)
        # The helper must swallow the metric error rather than propagate it.
        observe_pruning_diff_duration(1.0, "confluence")
class TestIncPruningRateLimitError:
    """inc_pruning_rate_limit_error increments the counter for one label only."""

    def test_increments_counter(self) -> None:
        before = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        inc_pruning_rate_limit_error("google_drive")
        after = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        assert after == before + 1

    def test_labels_by_connector_type(self) -> None:
        # Incrementing one connector_type must leave other labels untouched.
        before_gd = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        before_jira = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="jira"
        )._value.get()
        inc_pruning_rate_limit_error("google_drive")
        after_gd = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="google_drive"
        )._value.get()
        after_jira = PRUNING_RATE_LIMIT_ERRORS.labels(
            connector_type="jira"
        )._value.get()
        assert after_gd == before_gd + 1
        assert after_jira == before_jira

    def test_does_not_raise_on_exception(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Force labels() to raise; the helper must swallow the error.
        # (The lambda's .throw() on a fresh generator raises at call time.)
        monkeypatch.setattr(
            PRUNING_RATE_LIMIT_ERRORS,
            "labels",
            lambda **_: (_ for _ in ()).throw(RuntimeError("boom")),
        )
        inc_pruning_rate_limit_error("google_drive")