fix: CheckpointedConnector pruning only processes first checkpoint step (mirror of #8464) (#8468)

Co-authored-by: Yves Grolet <yves@grolet.com>
This commit is contained in:
Evan Lohn
2026-02-15 16:45:43 -08:00
committed by GitHub
parent 6aea36b573
commit 63b9a869af

View File

@@ -5,17 +5,19 @@ from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -31,6 +33,54 @@ from onyx.utils.logger import setup_logger
# Module-level logger obtained from the project's logging helper.
logger = setup_logger()
# Batch size handed to _checkpointed_batched_doc_ids when pruning a
# CheckpointedConnector: ID batches are yielded once they reach this size.
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
# Type variable for a connector's checkpoint type; it ties the
# CheckpointedConnector[CT] argument to its CheckpointOutputWrapper[CT].
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def _checkpointed_batched_doc_ids(
    connector: CheckpointedConnector[CT],
    start: float,
    end: float,
    batch_size: int,
) -> Generator[set[str], None, None]:
    """Yield sets of document IDs gathered across every checkpoint step.

    A checkpointed connector may need several ``load_from_checkpoint``
    calls before it has produced everything (e.g. IMAP, whose first call
    may only initialize internal state and yield no documents). The
    connector is therefore re-invoked with the most recent checkpoint
    until that checkpoint reports ``has_more`` as False.

    Args:
        connector: Connector to drive; it supplies the initial dummy
            checkpoint and the per-step output.
        start: Start of the time window, forwarded to the connector.
        end: End of the time window, forwarded to the connector.
        batch_size: Number of collected IDs at which a batch is yielded.

    Yields:
        Non-empty sets of document IDs. IDs come from successfully loaded
        documents and from failures that carry a failed document ID. A
        partially filled set is flushed at the end of each checkpoint
        step, so yielded sets may be smaller than ``batch_size``.
    """
    current_checkpoint = connector.build_dummy_checkpoint()
    more_steps = True

    while more_steps:
        step_output = connector.load_from_checkpoint(
            start=start, end=end, checkpoint=current_checkpoint
        )
        unwrap: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
        pending_ids: set[str] = set()

        for doc, _hierarchy_node, doc_failure, new_checkpoint in unwrap(step_output):
            if doc is not None:
                pending_ids.add(doc.id)
            elif doc_failure and doc_failure.failed_document:
                # Failed documents still count as present in the source.
                failed_id = doc_failure.failed_document.document_id
                if failed_id:
                    pending_ids.add(failed_id)

            # Remember the newest checkpoint so the next step resumes from it.
            if new_checkpoint is not None:
                current_checkpoint = new_checkpoint

            if len(pending_ids) >= batch_size:
                yield pending_ids
                pending_ids = set()

        # Flush whatever is left over from this step.
        if pending_ids:
            yield pending_ids

        more_steps = current_checkpoint.has_more
def document_batch_to_ids(
doc_batch: (
@@ -80,12 +130,8 @@ def extract_ids_from_runnable_connector(
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
doc_batch_id_generator = _checkpointed_batched_doc_ids(
runnable_connector, start, end, PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")