Compare commits

...

9 Commits

Author SHA1 Message Date
pablonyx
ab1498d60e clarify comment 2025-02-23 15:11:07 -08:00
pablonyx
a7b392a979 k 2025-02-23 12:50:48 -08:00
pablonyx
215d0d9001 quick update 2025-02-23 12:42:38 -08:00
pablonyx
9153339029 validated 2025-02-23 12:40:28 -08:00
pablonyx
34619aa1b7 k 2025-02-23 12:36:23 -08:00
pablonyx
89b3ccc02d k 2025-02-23 12:36:06 -08:00
pablonyx
b00794b79b fix typing 2025-02-23 11:45:45 -08:00
pablonyx
f50d26cc7d dedupe 2025-02-23 11:30:10 -08:00
pablonyx
ac300559e2 perm sync + pruning 2025-02-23 10:11:43 -08:00
7 changed files with 312 additions and 15 deletions

View File

@@ -271,6 +271,7 @@ def monitor_connector_deletion_taskset(
if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return
redis_connector.delete.detect_stuck_subtasks(cc_pair_id, r)
remaining = redis_connector.delete.get_remaining()
task_logger.info(
@@ -306,6 +307,9 @@ def monitor_connector_deletion_taskset(
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
# This may also happen if a `DOCUMENT_BY_CC_PAIR_CLEANUP_TASK` subtask does not
# complete successfully but we clear the fence after a timeout.
# In this case, we will re-attempt deletion of the remaining documents.
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "

View File

@@ -812,6 +812,9 @@ def monitor_ccpair_permissions_taskset(
if not payload:
return
# Check for stuck subtasks
redis_connector.permissions.detect_stuck_subtasks(cc_pair_id, r)
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: "

View File

@@ -532,6 +532,9 @@ def monitor_ccpair_pruning_taskset(
if initial is None:
return
# Check for stuck subtasks
redis_connector.prune.detect_stuck_subtasks(cc_pair_id, r)
remaining = redis_connector.prune.get_remaining()
task_logger.info(
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"

View File

@@ -18,6 +18,7 @@ from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
from onyx.db.document import delete_documents_complete__no_commit
from onyx.db.document import fetch_chunk_count_for_document
@@ -32,6 +33,8 @@ from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
@@ -44,6 +47,8 @@ DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
@shared_task(
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
@@ -57,6 +62,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
flow_type: str, # "prune" or "delete"
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
@@ -78,8 +84,32 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
cc_pair = None
request_id = self.request.id if self.request.id is not None else "missing-task-id"
try:
with get_session_with_current_tenant() as db_session:
cc_pair_local = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair_local:
task_logger.warning(
f"cc_pair not found for {connector_id} {credential_id}"
)
return False
cc_pair = cc_pair_local
if flow_type == "prune":
RedisConnectorPrune.update_subtask_heartbeat(
cc_pair.id, request_id, get_redis_client()
)
elif flow_type == "delete":
RedisConnectorDelete.update_subtask_heartbeat(
cc_pair.id, request_id, get_redis_client()
)
action = "skip"
chunks_affected = 0
@@ -166,6 +196,8 @@ def document_by_cc_pair_cleanup_task(
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
return True
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False

View File

@@ -32,6 +32,8 @@ class RedisConnectorDelete:
PREFIX = "connectordeletion"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
SUBTASK_CREATION_TIMES_PREFIX = f"{PREFIX}_subtask_creation_times"
SUBTASK_HEARTBEAT_PREFIX = f"{PREFIX}_subtask_heartbeat"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
@@ -41,6 +43,9 @@ class RedisConnectorDelete:
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_creation_times_key = f"{self.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
self.subtask_heartbeat_prefix = f"{self.SUBTASK_HEARTBEAT_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -120,6 +125,11 @@ class RedisConnectorDelete:
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
self.redis.sadd(self.taskset_key, custom_task_id)
# Record creation time in a dedicated hash
self.redis.hset(
self.subtask_creation_times_key, custom_task_id, str(time.time())
)
# Priority on syncs triggered by new indexing should be medium
result = celery_app.send_task(
OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
@@ -127,6 +137,7 @@ class RedisConnectorDelete:
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
flow_type="delete",
tenant_id=self.tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
@@ -147,8 +158,86 @@ class RedisConnectorDelete:
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
    """Remove a completed subtask from the deletion taskset.

    Clears both the taskset membership and the recorded creation time so
    that stuck-subtask detection no longer tracks the finished task.

    Args:
        id: cc_pair id that scopes the Redis keys.
        task_id: the custom celery task id to remove.
        r: Redis client.
    """
    taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}"
    creation_times_key = (
        f"{RedisConnectorDelete.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
    )
    r.srem(taskset_key, task_id)
    # BUG FIX: the hdel below was previously unreachable (it sat after a bare
    # `return`), which leaked creation-time entries for completed subtasks and
    # left stale bookkeeping behind for stuck-subtask detection.
    r.hdel(creation_times_key, task_id)
@staticmethod
def update_subtask_heartbeat(id: int, task_id: str, r: redis.Redis) -> None:
    """Subtask calls this to mark 'I am alive'.

    Writes the current wall-clock time under a per-subtask heartbeat key
    with a 5-minute TTL; a missing or stale heartbeat is treated as a sign
    the subtask died and never reported completion.
    """
    heartbeat_key = (
        f"{RedisConnectorDelete.SUBTASK_HEARTBEAT_PREFIX}_{id}:{task_id}"
    )
    # Store the timestamp as a string for consistency with the
    # permission-sync variant; Redis stores values as strings either way
    # and _parse_float handles the decode on read.
    r.set(heartbeat_key, str(time.time()), ex=300)  # e.g. 5-min TTL
@staticmethod
def _parse_float(val: bytes | str) -> float:
"""
Safely parse the raw Redis value (bytes/str) into a float or raise ValueError.
"""
if isinstance(val, bytes):
val_str = val.decode("utf-8")
else:
val_str = str(val)
return float(val_str)
@staticmethod
def detect_stuck_subtasks(
    id: int,
    r: redis.Redis,
    threshold_s: float = 600,
) -> None:
    """
    Removes stale or never-started subtasks from the deletion taskset
    if their heartbeat or creation time exceeds threshold_s seconds.

    For each task id still in the taskset:
      * if a heartbeat key exists but is older than threshold_s, the
        subtask is considered dead and is dropped from tracking;
      * if no heartbeat exists, the recorded creation time is checked the
        same way (covers subtasks that never started running).

    Dropping a subtask removes it from both the taskset and the
    creation-times hash so progress monitoring can converge.

    Raises:
        ValueError: if a stored heartbeat or creation-time value cannot
            be parsed as a float.
    """
    taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}"
    creation_times_key = (
        f"{RedisConnectorDelete.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
    )
    heartbeat_prefix = f"{RedisConnectorDelete.SUBTASK_HEARTBEAT_PREFIX}_{id}"
    now = time.time()
    # NOTE(review): assumes the client returns bytes (decode_responses
    # disabled) — the cast() calls below rely on the same assumption.
    for subtask_id_bytes in r.sscan_iter(taskset_key):
        subtask_id = subtask_id_bytes.decode("utf-8")
        hb_key = f"{heartbeat_prefix}:{subtask_id}"
        last_beat_raw = cast(bytes, r.get(hb_key))
        if last_beat_raw is not None:
            # Subtask heartbeated; check if stale
            try:
                last_beat_val = RedisConnectorDelete._parse_float(last_beat_raw)
            except ValueError:
                raise ValueError(
                    f"Failed to parse heartbeat value for subtask {subtask_id}"
                )
            if now - last_beat_val > threshold_s:
                # Dead subtask: stop tracking it entirely.
                r.srem(taskset_key, subtask_id)
                r.hdel(creation_times_key, subtask_id)
        else:
            # No heartbeat; check creation time
            creation_time_raw = cast(bytes, r.hget(creation_times_key, subtask_id))
            if creation_time_raw is not None:
                try:
                    creation_time_val = RedisConnectorDelete._parse_float(
                        creation_time_raw
                    )
                except ValueError:
                    raise ValueError(
                        f"Failed to parse creation time value for subtask {subtask_id}"
                    )
                if now - creation_time_val > threshold_s:
                    # Never-started (or never-heartbeated) subtask past the
                    # threshold: stop tracking it entirely.
                    r.srem(taskset_key, subtask_id)
                    r.hdel(creation_times_key, subtask_id)
@staticmethod
def reset_all(r: redis.Redis) -> None:

View File

@@ -52,6 +52,9 @@ class RedisConnectorPermissionSync:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
SUBTASK_CREATION_TIMES_PREFIX = f"{PREFIX}_subtask_creation_times"
SUBTASK_HEARTBEAT_PREFIX = f"{PREFIX}_subtask_heartbeat"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -67,6 +70,9 @@ class RedisConnectorPermissionSync:
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
self.subtask_creation_times_key = f"{self.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
self.subtask_heartbeat_prefix = f"{self.SUBTASK_HEARTBEAT_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -163,41 +169,45 @@ class RedisConnectorPermissionSync:
def generate_tasks(
self,
celery_app: Celery,
lock: RedisLock | None,
lock: RedisLock,
new_permissions: list[DocExternalAccess],
source_string: str,
connector_id: int,
credential_id: int,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
# Create a task for each document permission sync
for doc_perm in new_permissions:
async_results = []
for doc_external_access in new_permissions:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock.reacquire()
last_lock_time = current_time
# Add task for document permissions sync
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
# Add to the tracking taskset in redis
self.redis.sadd(self.taskset_key, custom_task_id)
# Record creation time as a string
self.redis.hset(
self.subtask_creation_times_key, custom_task_id, str(time.time())
)
result = celery_app.send_task(
OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
kwargs=dict(
tenant_id=self.tenant_id,
serialized_doc_external_access=doc_perm.to_dict(),
serialized_doc_external_access=doc_external_access.to_dict(),
source_string=source_string,
connector_id=connector_id,
credential_id=credential_id,
),
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
ignore_result=True,
priority=OnyxCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
@@ -213,9 +223,83 @@ class RedisConnectorPermissionSync:
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
    """Drop a finished subtask from the permission-sync taskset and
    forget its recorded creation time."""
    cls = RedisConnectorPermissionSync
    r.srem(f"{cls.TASKSET_PREFIX}_{id}", task_id)
    r.hdel(f"{cls.SUBTASK_CREATION_TIMES_PREFIX}_{id}", task_id)
@staticmethod
def update_subtask_heartbeat(id: int, subtask_id: str, r: redis.Redis) -> None:
    """Record that the given permission-sync subtask is still alive.

    The heartbeat key carries a 5-minute TTL, so a dead subtask's entry
    expires on its own.
    """
    key = "{}_{}:{}".format(
        RedisConnectorPermissionSync.SUBTASK_HEARTBEAT_PREFIX, id, subtask_id
    )
    r.set(key, str(time.time()), ex=300)  # TTL set to 5 minutes
@staticmethod
def _parse_float(val: bytes | str) -> float:
if isinstance(val, bytes):
val_str = val.decode("utf-8")
else:
val_str = str(val)
return float(val_str)
@staticmethod
def detect_stuck_subtasks(
    id: int,
    r: redis.Redis,
    threshold_s: float = 600,
) -> None:
    """
    Removes stale or never-started subtasks from the permission sync taskset
    if their heartbeat or creation time exceeds threshold_s seconds.

    For each task id still in the taskset:
      * a heartbeat older than threshold_s means the subtask died mid-run;
      * with no heartbeat at all, the recorded creation time is checked the
        same way (covers subtasks that never started running).

    Either way the subtask is removed from both the taskset and the
    creation-times hash so progress monitoring can converge.

    Raises:
        ValueError: if a stored heartbeat or creation-time value cannot
            be parsed as a float.
    """
    taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
    creation_times_key = (
        f"{RedisConnectorPermissionSync.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
    )
    heartbeat_prefix = (
        f"{RedisConnectorPermissionSync.SUBTASK_HEARTBEAT_PREFIX}_{id}"
    )
    now = time.time()
    # NOTE(review): assumes the client returns bytes (decode_responses
    # disabled) — the cast() calls below rely on the same assumption.
    for subtask_id_bytes in r.sscan_iter(taskset_key):
        subtask_id = subtask_id_bytes.decode("utf-8")
        hb_key = f"{heartbeat_prefix}:{subtask_id}"
        last_beat_raw = cast(bytes, r.get(hb_key))
        if last_beat_raw is not None:
            # Subtask heartbeated at least once; check whether it went stale.
            try:
                last_beat_val = RedisConnectorPermissionSync._parse_float(
                    last_beat_raw
                )
            except ValueError:
                raise ValueError(
                    f"Failed to parse heartbeat value for subtask {subtask_id}"
                )
            if now - last_beat_val > threshold_s:
                r.srem(taskset_key, subtask_id)
                r.hdel(creation_times_key, subtask_id)
        else:
            # No heartbeat; fall back to the creation time recorded at enqueue.
            creation_time_raw = cast(bytes, r.hget(creation_times_key, subtask_id))
            if creation_time_raw is not None:
                try:
                    creation_time_val = RedisConnectorPermissionSync._parse_float(
                        creation_time_raw
                    )
                except ValueError:
                    raise ValueError(
                        f"Failed to parse creation time value for subtask {subtask_id}"
                    )
                if now - creation_time_val > threshold_s:
                    r.srem(taskset_key, subtask_id)
                    r.hdel(creation_times_key, subtask_id)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""

View File

@@ -52,6 +52,9 @@ class RedisConnectorPrune:
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
SUBTASK_CREATION_TIMES_PREFIX = f"{PREFIX}_subtask_creation_times"
SUBTASK_HEARTBEAT_PREFIX = f"{PREFIX}_subtask_heartbeat"
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -62,11 +65,16 @@ class RedisConnectorPrune:
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.taskset_key = (
f"{self.TASKSET_PREFIX}_{id}" # connectorpruning_taskset_{id}
)
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
self.subtask_creation_times_key = f"{self.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
self.subtask_heartbeat_prefix = f"{self.SUBTASK_HEARTBEAT_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -187,16 +195,21 @@ class RedisConnectorPrune:
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# Add to the tracking taskset in redis
self.redis.sadd(self.taskset_key, custom_task_id)
# Priority on syncs triggered by new indexing should be medium
# Record creation time in a dedicated hash
self.redis.hset(
self.subtask_creation_times_key, custom_task_id, str(time.time())
)
result = celery_app.send_task(
OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
flow_type="prune",
tenant_id=self.tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
@@ -220,9 +233,78 @@ class RedisConnectorPrune:
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
    """Drop a completed subtask from the pruning taskset and clear its
    recorded creation time."""
    base = RedisConnectorPrune
    r.srem(f"{base.TASKSET_PREFIX}_{id}", task_id)
    r.hdel(f"{base.SUBTASK_CREATION_TIMES_PREFIX}_{id}", task_id)
@staticmethod
def update_subtask_heartbeat(id: int, subtask_id: str, r: redis.Redis) -> None:
    """Mark the given pruning subtask as alive.

    Writes the current time under a per-subtask key with a 5-minute TTL;
    a missing or stale heartbeat is treated as a sign the subtask died.
    """
    heartbeat_key = (
        f"{RedisConnectorPrune.SUBTASK_HEARTBEAT_PREFIX}_{id}:{subtask_id}"
    )
    # Store the timestamp as a string for consistency with
    # RedisConnectorPermissionSync; _parse_float handles the decode on read.
    r.set(heartbeat_key, str(time.time()), ex=300)  # TTL set to 5 minutes
@staticmethod
def _parse_float(val: bytes | str) -> float:
"""
Safely parse the raw Redis value (bytes/str) into a float or raise ValueError.
"""
if isinstance(val, bytes):
val_str = val.decode("utf-8")
else:
val_str = str(val)
return float(val_str)
@staticmethod
def detect_stuck_subtasks(
    id: int,
    r: redis.Redis,
    threshold_s: float = 600,
) -> None:
    """
    Removes stale or never-started subtasks from the pruning taskset
    if their heartbeat or creation time exceeds threshold_s seconds.

    For each task id still in the taskset:
      * a heartbeat older than threshold_s means the subtask died mid-run;
      * with no heartbeat at all, the recorded creation time is checked the
        same way (covers subtasks that never started running).

    Either way the subtask is removed from both the taskset and the
    creation-times hash so progress monitoring can converge.

    Raises:
        ValueError: if a stored heartbeat or creation-time value cannot
            be parsed as a float.
    """
    taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}"
    creation_times_key = f"{RedisConnectorPrune.SUBTASK_CREATION_TIMES_PREFIX}_{id}"
    heartbeat_prefix = f"{RedisConnectorPrune.SUBTASK_HEARTBEAT_PREFIX}_{id}"
    now = time.time()
    # NOTE(review): assumes the client returns bytes (decode_responses
    # disabled) — the cast() calls below rely on the same assumption.
    for subtask_id_bytes in r.sscan_iter(taskset_key):
        subtask_id = subtask_id_bytes.decode("utf-8")
        hb_key = f"{heartbeat_prefix}:{subtask_id}"
        last_beat_raw = cast(bytes, r.get(hb_key))
        if last_beat_raw is not None:
            # Subtask heartbeated at least once; check whether it went stale.
            try:
                last_beat_val = RedisConnectorPrune._parse_float(last_beat_raw)
            except ValueError:
                raise ValueError(
                    f"Failed to parse heartbeat value for subtask {subtask_id}"
                )
            if now - last_beat_val > threshold_s:
                r.srem(taskset_key, subtask_id)
                r.hdel(creation_times_key, subtask_id)
        else:
            # No heartbeat; fall back to the creation time recorded at enqueue.
            creation_time_raw = cast(bytes, r.hget(creation_times_key, subtask_id))
            if creation_time_raw is not None:
                try:
                    creation_time_val = RedisConnectorPrune._parse_float(
                        creation_time_raw
                    )
                except ValueError:
                    raise ValueError(
                        f"Failed to parse creation time value for subtask {subtask_id}"
                    )
                if now - creation_time_val > threshold_s:
                    r.srem(taskset_key, subtask_id)
                    r.hdel(creation_times_key, subtask_id)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""