Compare commits

...

6 Commits

Author SHA1 Message Date
Richard Kuo (Danswer)
6883e57cb8 Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/locking_5 2025-01-06 19:14:41 -08:00
Richard Kuo (Danswer)
8f01fe1dae use a larger scan_iter value for performance 2025-01-06 16:43:46 -08:00
Richard Kuo (Danswer)
1f82a3dcdf move lock_beat test outside the try catch so that we don't worry about testing locks we never took 2025-01-06 16:03:44 -08:00
Richard Kuo (Danswer)
59934e6cfe more logging 2025-01-06 13:51:10 -08:00
Richard Kuo (Danswer)
760f946e60 test reacquire outside of loop 2025-01-06 10:56:30 -08:00
Richard Kuo (Danswer)
701d701894 more debugging 2025-01-05 22:05:59 -08:00
10 changed files with 111 additions and 58 deletions

View File

@@ -44,11 +44,11 @@ def check_for_connector_deletion_task(
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:

View File

@@ -102,11 +102,11 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:

View File

@@ -102,11 +102,11 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)

View File

@@ -63,6 +63,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -219,11 +220,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
locked = True
# check for search settings swap
@@ -257,6 +258,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
# debugging logic - remove after we're done
if (
tenant_id == "tenant_i-043470d740845ec56"
or tenant_id == "tenant_82b497ce-88aa-4fbd-841a-92cae43529c8"
):
task_logger.info(
f"check_for_indexing lock: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"ttl={redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list: list[SearchSettings] = get_active_search_settings(
@@ -405,7 +418,9 @@ def validate_indexing_fences(
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(

View File

@@ -89,11 +89,11 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)

View File

@@ -4,6 +4,7 @@ import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import Any
from typing import cast
import httpx
@@ -70,6 +71,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -103,11 +105,11 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, db_session, r, lock_beat, tenant_id
@@ -217,6 +219,8 @@ def try_generate_stale_document_sync_tasks(
total_tasks_generated = 0
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
@@ -752,7 +756,7 @@ def monitor_ccpair_indexing_taskset(
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
@@ -766,7 +770,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
time_start = time.monotonic()
timings: dict[str, float] = {}
timings: dict[str, Any] = {}
timings["start"] = time_start
r = get_redis_client(tenant_id=tenant_id)
@@ -776,16 +780,15 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
task_logger.info("monitor_vespa_sync exiting due to overlap")
return False
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return None
try:
# print current queue lengths
phase_start = time.monotonic()
# we don't need every tenant polling redis for this info.
if not MULTI_TENANT or random.randint(1, 100) == 100:
if not MULTI_TENANT or random.randint(1, 10) == 10:
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
@@ -826,6 +829,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"permissions_upsert={n_permissions_upsert} "
)
timings["queues"] = time.monotonic() - phase_start
timings["queues_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# scan and monitor activity to completion
phase_start = time.monotonic()
@@ -833,24 +837,37 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
timings["connector"] = time.monotonic() - phase_start
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
lock_beat.reacquire()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
timings["connector_deletion"] = time.monotonic() - phase_start
timings["connector_deletion_ttl"] = r.ttl(
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK
)
phase_start = time.monotonic()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
timings["document_set"] = time.monotonic() - phase_start
lock_beat.reacquire()
timings["documentset"] = time.monotonic() - phase_start
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
lock_beat.reacquire()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"onyx.background.celery.tasks.vespa.tasks",
"monitor_usergroup_taskset",
@@ -858,29 +875,45 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
)
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["usergroup"] = time.monotonic() - phase_start
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
lock_beat.reacquire()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["pruning"] = time.monotonic() - phase_start
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["indexing"] = time.monotonic() - phase_start
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
phase_start = time.monotonic()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
lock_beat.reacquire()
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
timings["permissions"] = time.monotonic() - phase_start
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -889,18 +922,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if lock_beat.owned():
lock_beat.release()
else:
t = timings
task_logger.error(
"monitor_vespa_sync - Lock not owned on completion: "
f"tenant={tenant_id} "
f"queues={t.get('queues')} "
f"connector={t.get('connector')} "
f"connector_deletion={t.get('connector_deletion')} "
f"document_set={t.get('document_set')} "
f"usergroup={t.get('usergroup')} "
f"pruning={t.get('pruning')} "
f"indexing={t.get('indexing')} "
f"permissions={t.get('permissions')}"
f"timings={timings}"
)
redis_lock_dump(lock_beat, r)

View File

@@ -13,6 +13,7 @@ from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPermissionSyncPayload(BaseModel):
@@ -68,7 +69,10 @@ class RedisConnectorPermissionSync:
def get_active_task_count(self) -> int:
"""Count of active permission sync tasks"""
count = 0
for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
for _ in self.redis.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
count += 1
return count

View File

@@ -7,6 +7,8 @@ from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorExternalGroupSyncPayload(BaseModel):
started: datetime | None
@@ -63,7 +65,8 @@ class RedisConnectorExternalGroupSync:
"""Count of active external group syncing tasks"""
count = 0
for _ in self.redis.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
count += 1
return count

View File

@@ -12,6 +12,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPrune:
@@ -63,7 +64,9 @@ class RedisConnectorPrune:
def get_active_task_count(self) -> int:
"""Count of active pruning tasks"""
count = 0
for key in self.redis.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
for key in self.redis.scan_iter(
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
count += 1
return count

View File

@@ -28,6 +28,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
SCAN_ITER_COUNT_DEFAULT = 4096
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
@@ -115,6 +117,7 @@ class TenantRedis(redis.Redis):
"hexists",
"hset",
"hdel",
"ttl",
] # Regular methods that need simple prefixing
if item == "scan_iter":