Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: cloud_debu...v0.13.2 (18 Commits)
| SHA1 |
|---|
| 4e4c48291a |
| 871918738f |
| b1780732b9 |
| e92f48c17d |
| 8fad5f7e5e |
| 310732d10f |
| 322d7cdc90 |
| 67d943da11 |
| 9456fef307 |
| cc3c0800f0 |
| e860f15b64 |
| 574ef470a4 |
| 9e391495c2 |
| e26d5430fa |
| cce0ec2f22 |
| a4f09a62a5 |
| fd2428d97f |
| cfc46812c8 |
@@ -59,7 +59,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)

# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
@@ -14,10 +14,16 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
@@ -134,6 +140,23 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:

RedisConnectorStop.reset_all(r)

# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue

failure_reason = (
f"Orphaned index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@@ -1,12 +1,12 @@
from datetime import datetime
from datetime import timezone

import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
@@ -87,7 +87,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
@@ -10,6 +10,8 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.background.celery.apps.app_base import task_logger
@@ -32,6 +34,8 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -44,7 +48,8 @@ from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -61,14 +66,18 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()

self.last_lock_reacquire: datetime = datetime.now(timezone.utc)

def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
@@ -76,10 +85,70 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
return False

def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
try:
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise

self.redis_client.incrby(self.generator_progress_key, amount)
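
Editor's note: the callback above uses redis-py's `Lock.reacquire()` as a heartbeat, so a TTL-bound lock stays owned while indexing keeps making progress, and a `LockError` surfaces the moment ownership is lost. A minimal sketch of that pattern, assuming a local Redis server on the default port (the lock name and loop are illustrative, not from this repo):

```python
import time

import redis
from redis.lock import Lock as RedisLock

r = redis.Redis()  # assumes a local Redis at localhost:6379

# hypothetical lock name for illustration only
lock: RedisLock = r.lock("example:heartbeat_lock", timeout=60)

if lock.acquire(blocking=False):
    try:
        for _chunk in range(10):  # stand-in for real units of work
            time.sleep(1)
            # extends the TTL back to `timeout`; raises LockError if the
            # lock expired or was taken over in the meantime
            lock.reacquire()
    finally:
        if lock.owned():
            lock.release()
```
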

def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.

Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []

# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down

# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)

for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)

# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue

# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.

# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue

if attempt.status != attempt_2.status:
continue

unfenced_attempts.append(attempt.id)

return unfenced_attempts
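
Editor's note: the "inner/outer/inner" comments above describe a double check: read the DB state (inner), check the Redis fence (outer), then re-read the DB to confirm the mismatch was not just a race. A self-contained toy sketch of that shape, using dict stand-ins for the DB and the fence store (hypothetical names, not this repo's API):

```python
from dataclasses import dataclass


@dataclass
class Attempt:
    id: int
    status: str  # e.g. "not_started", "in_progress", "success"


# toy stand-ins for the database and the Redis fence keys
db: dict[int, Attempt] = {1: Attempt(id=1, status="in_progress")}
fences: set[int] = set()  # attempt 1 has no fence, so it looks orphaned

TERMINAL = {"success", "failed", "canceled"}


def is_really_orphaned(attempt_id: int) -> bool:
    first = db.get(attempt_id)
    if first is None or first.status in TERMINAL:  # inner: non-terminal?
        return False
    if attempt_id in fences:                       # outer: fence still up?
        return False
    second = db.get(attempt_id)                    # inner again: re-check
    if second is None or second.status != first.status:
        return False                               # state moved: just a race
    return True                                    # confirmed bad state


print(is_really_orphaned(1))  # True in this toy setup
```
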

@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -90,7 +159,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

r = get_redis_client(tenant_id=tenant_id)

lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -100,6 +169,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None

# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -118,13 +188,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)

# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)

# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()

redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -180,6 +255,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1

# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()

attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue

failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)

except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -189,6 +287,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)

return tasks_created

@@ -293,10 +396,11 @@ def try_creating_indexing_task(
"""

LOCK_TIMEOUT = 30
index_attempt_id: int | None = None

# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -325,7 +429,7 @@
redis_connector_index.generator_clear()

# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
@@ -347,6 +451,8 @@

custom_task_id = redis_connector_index.generate_generator_task_id()

# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -368,13 +474,16 @@
redis_connector_index.set_fence(payload)

except Exception:
redis_connector_index.set_fence(payload)
task_logger.exception(
f"Unexpected exception: "
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)

if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -392,7 +501,7 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -400,7 +509,7 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()

job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -411,7 +520,7 @@ def connector_indexing_proxy_task(

if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -419,7 +528,7 @@ def connector_indexing_proxy_task(
return

task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -443,7 +552,7 @@ def connector_indexing_proxy_task(

if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -455,7 +564,7 @@ def connector_indexing_proxy_task(
break

task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -463,6 +572,38 @@ def connector_indexing_proxy_task(
return


def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None

try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise

return result


def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -516,6 +657,7 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -523,18 +665,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)

while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)

payload = redis_connector_index.payload
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")

@@ -557,7 +699,7 @@ def connector_indexing_task(
)
break

lock = r.lock(
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -566,7 +708,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import cast

import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -13,6 +12,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError

@@ -45,13 +45,10 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -162,7 +159,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -180,7 +177,12 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)

task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)

docs_to_skip: set[str] = set()

# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
@@ -188,22 +190,21 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)

if tasks_generated is None:
if result is None:
continue

if tasks_generated == 0:
if result[1] == 0:
continue

task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)

total_tasks_generated += tasks_generated
total_tasks_generated += result[0]

task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -218,7 +219,7 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -246,12 +247,11 @@ def try_generate_document_set_sync_tasks(
)

# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None

tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -260,7 +260,7 @@ def try_generate_document_set_sync_tasks(

task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)

# set this only after all tasks have been added
@@ -273,7 +273,7 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -302,12 +302,11 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None

tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -316,7 +315,7 @@ def try_generate_user_group_sync_tasks(

task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)

# set this only after all tasks have been added
@@ -580,8 +579,8 @@ def monitor_ccpair_indexing_taskset(
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -590,30 +589,46 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return

# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state

# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state

# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
if status_int is None: # inner signal not set ... possible error
result_state = result.state
if (
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.

index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
task_logger.warning(msg)

index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)

redis_connector_index.reset()

redis_connector_index.reset()
return
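
Editor's note: as the comment above stresses, the monitor never blocks on the child task's result; it only reads `AsyncResult.state` and compares it against `READY_STATES`. A minimal sketch of that non-blocking check, assuming a configured result backend (the task id and backend URL are placeholders):

```python
from celery import Celery
from celery.result import AsyncResult
from celery.states import READY_STATES

# in-memory cache backend is enough for a standalone demo
app = Celery("demo", backend="cache+memory://")

result = AsyncResult("00000000-0000-0000-0000-000000000000", app=app)

if result.state in READY_STATES:
    # SUCCESS / FAILURE / REVOKED: the task reached a terminal state
    print("child task finished:", result.state)
else:
    # PENDING / STARTED / RETRY: check again on the next monitoring pass,
    # never call result.get() here since it would block inside a task
    print("child task not finished yet:", result.state)
```
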

@@ -621,8 +636,8 @@ def monitor_ccpair_indexing_taskset(
status_enum = HTTPStatus(status_int)

task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -643,7 +658,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""
r = get_redis_client(tenant_id=tenant_id)

lock_beat: redis.lock.Lock = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -677,32 +692,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"pruning={n_pruning}"
)

# do some cleanup before clearing fences
# check the db for any outstanding index attempts
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)

for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)

lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
@@ -433,11 +433,13 @@ def run_indexing_entrypoint(
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)

tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"

logger.info(
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -445,10 +447,8 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)

logger.info(
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

@@ -74,7 +74,7 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120

# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 3 hours

# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
@@ -3,6 +3,8 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote

from atlassian import Confluence # type: ignore

from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
@@ -70,7 +72,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.confluence_client: OnyxConfluence | None = None
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud

# Remove trailing slash from wiki_base if present
@@ -97,39 +99,59 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
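
Editor's note: for illustration, here is what the quoted label filter above produces for a hypothetical `labels_to_skip` value (not a value from this repo's config):

```python
from urllib.parse import quote

labels_to_skip = ["archived", "no index"]
comma_separated_labels = ",".join(f"'{quote(label)}'" for label in labels_to_skip)
cql_label_filter = f" and label not in ({comma_separated_labels})"
print(cql_label_filter)
# -> " and label not in ('archived','no%20index')"
```
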

@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = build_confluence_client(
self._confluence_client = build_confluence_client(
credentials_json=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)

client_without_retries = Confluence(
api_version="cloud" if self.is_cloud else "latest",
url=self.wiki_base.rstrip("/"),
username=credentials["confluence_username"] if self.is_cloud else None,
password=credentials["confluence_access_token"] if self.is_cloud else None,
token=credentials["confluence_access_token"] if not self.is_cloud else None,
)
spaces = client_without_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {self.wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
return None

def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

comment_string = ""

comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter

expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comments in self.confluence_client.paginated_cql_page_retrieval(
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
):
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)

return comment_string
@@ -141,28 +163,30 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"]
base_url=self.wiki_base,
content_url=confluence_object["_links"]["webui"],
is_cloud=self.is_cloud,
)

object_text = None
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
self.confluence_client, confluence_object
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
self.confluence_client, confluence_object
confluence_client=self.confluence_client, attachment=confluence_object
)

if object_text is None:
# This only happens for attachments that are not parseable
return None

# Get space name
@@ -193,44 +217,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)

def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

doc_batch: list[Document] = []
confluence_page_ids: list[str] = []

page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []

# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachments in self.confluence_client.paginated_cql_page_retrieval(
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []

if doc_batch:
yield doc_batch
@@ -255,48 +274,47 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")

doc_metadata_list: list[SlimDocument] = []

restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)

page_query = self.cql_page_query + self.cql_label_filter
for pages in self.confluence_client.cql_paginate_all_expansions(
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
):
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}

doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
yield doc_metadata_list
doc_metadata_list = []
@@ -20,6 +20,10 @@ F = TypeVar("F", bound=Callable[..., Any])

RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()

# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"


class ConfluenceRateLimitError(Exception):
pass
@@ -80,7 +84,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5

TIMEOUT = 3600
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT

for attempt in range(MAX_RETRIES):
@@ -88,13 +92,16 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)

try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
@@ -103,7 +110,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e

logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
@@ -141,7 +147,7 @@ class OnyxConfluence(Confluence):

def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
"""
@@ -153,46 +159,43 @@ class OnyxConfluence(Confluence):

while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
logger.warning(f"Error in confluence call to {url_suffix}")

# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue

# yield the results individually
yield from next_response.get("results", [])

url_suffix = next_response.get("_links", {}).get("next")

def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)

def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)

def paginated_cql_user_retrieval(
def paginated_cql_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)

def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
yield from self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)

@@ -201,7 +204,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -221,6 +224,44 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)

for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object

def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a separate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
expand_string = f"&expand={expand}" if expand else ""
yield from self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)

def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)

def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
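
Editor's note: with these retrieval methods now yielding one Confluence object at a time instead of page-sized lists, batching moves to the caller, as the connector hunks above show. A minimal sketch of that calling pattern with a fake client (not `OnyxConfluence`):

```python
from collections.abc import Iterator
from typing import Any


class FakeClient:
    def paginated_cql_retrieval(self, cql: str) -> Iterator[dict[str, Any]]:
        # stand-in for following the _links.next chain across result pages
        for i in range(7):
            yield {"id": str(i), "type": "page"}


def batched_docs(
    client: FakeClient, cql: str, batch_size: int
) -> Iterator[list[dict[str, Any]]]:
    batch: list[dict[str, Any]] = []
    for obj in client.paginated_cql_retrieval(cql):
        batch.append(obj)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


for docs in batched_docs(FakeClient(), "type=page", batch_size=3):
    print([d["id"] for d in docs])  # ['0','1','2'], ['3','4','5'], ['6']
```
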

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote

import bs4

@@ -71,7 +72,9 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:


def extract_text_from_confluence_html(
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -79,7 +82,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client

fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
@@ -100,6 +103,73 @@ def extract_text_from_confluence_html(
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))

for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue

page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because page data is missing"
)
continue

page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue

if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue

fetched_titles.add(page_title)

# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"

page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue

if not page_contents:
continue

text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)

html_page_reference.replaceWith(text_from_page)

for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")

return format_document_soup(soup)
@@ -153,7 +223,9 @@ def attachment_to_content(
return extracted_text


def build_confluence_document_id(base_url: str, content_url: str) -> str:
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document

@@ -164,6 +236,8 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str:
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
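
Editor's note: to make the `is_cloud` handling above concrete, here is the same function applied to hypothetical cloud and server base URLs:

```python
def build_confluence_document_id(base_url: str, content_url: str, is_cloud: bool) -> str:
    # Confluence Cloud serves content under a /wiki prefix; server/DC does not
    if is_cloud and not base_url.endswith("/wiki"):
        base_url += "/wiki"
    return f"{base_url}{content_url}"


print(build_confluence_document_id(
    "https://example.atlassian.net", "/spaces/ENG/pages/123", is_cloud=True
))   # https://example.atlassian.net/wiki/spaces/ENG/pages/123

print(build_confluence_document_id(
    "https://confluence.example.com", "/spaces/ENG/pages/123", is_cloud=False
))  # https://confluence.example.com/spaces/ENG/pages/123
```
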

@@ -209,6 +283,6 @@ def build_confluence_client(
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_retries=10,
max_backoff_seconds=60,
)
@@ -192,23 +192,33 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._retrieved_ids.add(folder_id)
|
||||
|
||||
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
# Start with primary admin email
|
||||
user_emails = [self.primary_admin_email]
|
||||
|
||||
# Only fetch additional users if using service account
|
||||
if isinstance(self.creds, OAuthCredentials):
|
||||
return user_emails
|
||||
|
||||
admin_service = get_admin_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
query = "isAdmin=true" if admins_only else "isAdmin=false"
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
# Get admins first since they're more likely to have access to most files
|
||||
for is_admin in [True, False]:
|
||||
query = "isAdmin=true" if is_admin else "isAdmin=false"
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
if email not in user_emails:
|
||||
user_emails.append(email)
|
||||
return user_emails
|
||||
|
||||
def _get_all_drive_ids(self) -> set[str]:
|
||||
primary_drive_service = get_drive_service(
|
||||
@@ -216,55 +226,48 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
user_email=self.primary_admin_email,
)
all_drive_ids = set()
# We don't want to fail if we're using OAuth because you can
# access your my drive as a non admin user in an org still
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
continue_on_404_or_403=ignore_fetch_failure,
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
return all_drive_ids

def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first because they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails

self._all_drive_ids: set[str] = self._get_all_drive_ids()

# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids

# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
if not all_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
"No drives found. This is likely because oauth user "
"is not an admin and cannot view all drive IDs. "
"Continuing with only the shared drive IDs specified in the config."
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
all_drive_ids = set(self._requested_shared_drive_ids)

if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
return all_drive_ids

def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)

# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - no specific emails were requested
# - the current user's email is in the requested emails
# - we are using OAuth (in which case we assume that is the only email we will try)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
or isinstance(self.creds, OAuthCredentials)
):
yield from get_all_files_in_my_drive(
service=drive_service,
@@ -274,7 +277,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)

remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
@@ -285,7 +288,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)

remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
@@ -302,22 +305,56 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
self._initialize_all_class_variables()
all_org_emails: list[str] = self._get_all_user_emails()

all_drive_ids: set[str] = self._get_all_drive_ids()

# remove drive ids from the folder ids because they are queried differently
filtered_folder_ids = self._requested_folder_ids - all_drive_ids

# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_drive_ids)

# If including shared drives, use the requested IDs if provided,
# otherwise use all drive IDs
filtered_drive_ids = set()
if self.include_shared_drives:
if self._requested_shared_drive_ids:
# Remove invalid drive IDs from requested IDs
filtered_drive_ids = (
self._requested_shared_drive_ids - invalid_drive_ids
)
else:
filtered_drive_ids = all_drive_ids

# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval, email, is_slim, start, end
self._impersonate_user_for_retrieval,
email,
is_slim,
filtered_drive_ids,
filtered_folder_ids,
start,
end,
): email
for email in self._all_org_emails
for email in all_org_emails
}

# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()

remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = (
filtered_drive_ids | filtered_folder_ids
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"

@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.warning(f"Error executing request: {e}")
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e

@@ -169,6 +169,7 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
@@ -323,23 +324,23 @@ def upsert_documents(


def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session, connector_id: int, credential_id: int, document_ids: list[str]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
if not document_metadata_batch:
logger.info("`document_metadata_batch` is empty. Skipping.")
if not document_ids:
logger.info("`document_ids` is empty. Skipping.")
return

insert_stmt = insert(DocumentByConnectorCredentialPair).values(
[
model_to_dict(
DocumentByConnectorCredentialPair(
id=document_metadata.document_id,
connector_id=document_metadata.connector_id,
credential_id=document_metadata.credential_id,
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
)
)
for document_metadata in document_metadata_batch
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
@@ -400,17 +401,6 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None:
db_session.commit()


def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)


def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
@@ -520,7 +510,7 @@ def prepare_to_modify_documents(
db_session.commit() # ensure that we're not in a transaction

lock_acquired = False
for _ in range(_NUM_LOCK_ATTEMPTS):
for i in range(_NUM_LOCK_ATTEMPTS):
try:
with db_session.begin() as transaction:
lock_acquired = acquire_document_locks(
@@ -531,7 +521,7 @@ def prepare_to_modify_documents(
break
except OperationalError as e:
logger.warning(
f"Failed to acquire locks for documents, retrying. Error: {e}"
f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}"
)

time.sleep(retry_delay)

@@ -312,7 +312,9 @@ async def get_async_session_with_tenant(
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
except Exception:
logger.exception("Error setting search_path.")
@@ -373,7 +375,9 @@ def get_session_with_tenant(
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()

@@ -67,6 +67,13 @@ def create_index_attempt(
return new_attempt.id


def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt:
db_session.delete(index_attempt)
db_session.commit()


def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,

@@ -20,7 +20,8 @@ from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_documents_complete
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.document import upsert_documents
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import Document as DBDocument
@@ -56,13 +57,13 @@ class IndexingPipelineProtocol(Protocol):
...


def upsert_documents_in_db(
def _upsert_documents_in_db(
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
) -> None:
# Metadata here refers to basic document info, not metadata about the actual content
doc_m_batch: list[DocumentMetadata] = []
document_metadata_list: list[DocumentMetadata] = []
for doc in documents:
first_link = next(
(section.link for section in doc.sections if section.link), ""
@@ -77,12 +78,9 @@ def upsert_documents_in_db(
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
from_ingestion_api=doc.from_ingestion_api,
)
doc_m_batch.append(db_doc_metadata)
document_metadata_list.append(db_doc_metadata)

upsert_documents_complete(
db_session=db_session,
document_metadata_batch=doc_m_batch,
)
upsert_documents(db_session, document_metadata_list)

# Insert document content metadata
for doc in documents:
@@ -95,21 +93,25 @@ def upsert_documents_in_db(
document_id=doc.id,
db_session=db_session,
)
else:
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
continue

create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)


def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
"""Figures out which documents actually need to be updated. If a document is already present
and the `updated_at` hasn't changed, we shouldn't need to do anything with it."""
and the `updated_at` hasn't changed, we shouldn't need to do anything with it.

NB: Still need to associate the document in the DB if multiple connectors are
indexing the same doc."""
id_update_time_map = {
doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
}
@@ -195,9 +197,9 @@ def index_doc_batch_prepare(
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This precedes indexing it into the actual document index."""
documents = []
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
@@ -212,43 +214,58 @@ def index_doc_batch_prepare(
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
elif (
document.title is not None and not document.title.strip() and empty_contents
):
continue

if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unusable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
else:
documents.append(document)
continue

document_ids = [document.id for document in documents]
documents.append(document)

# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
db_session=db_session,
document_ids=document_ids,
)

# Skip indexing docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
updatable_docs = (
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
if not ignore_time_skip
else documents
)

# No docs to update either because the batch is empty or every doc was already indexed
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
if updatable_docs:
_upsert_documents_in_db(
documents=updatable_docs,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)

logger.info(
f"Upserted {len(updatable_docs)} changed docs out of "
f"{len(documents)} total docs into the DB"
)

# for all docs, upsert the document to cc pair relationship
upsert_document_by_connector_credential_pair(
db_session,
index_attempt_metadata.connector_id,
index_attempt_metadata.credential_id,
document_ids,
)

# No docs to process because the batch is empty or every doc was already indexed
if not updatable_docs:
return None

# Create records in the source of truth about these documents,
# does not include doc_updated_at which is also used to indicate a successful update
upsert_documents_in_db(
documents=documents,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)

id_to_db_doc_map = {doc.id: doc for doc in db_docs}
return DocumentBatchPrepareContext(
updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map
@@ -269,7 +286,10 @@ def index_doc_batch(
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
memory requirements

Returns a tuple where the first element is the number of new docs and the
second element is the number of chunks."""

no_access = DocumentAccess.build(
user_emails=[],
@@ -312,9 +332,9 @@ def index_doc_batch(

# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for immediate and most likely correct sync, but
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync
# always triggers a final metadata sync via the celery queue
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
@@ -351,7 +371,8 @@ def index_doc_batch(
ids_to_new_updated_at = {}
for doc in successful_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the connector source's idea of when the doc was last modified
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
if doc.doc_updated_at is None:
continue
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
@@ -366,10 +387,13 @@ def index_doc_batch(

db_session.commit()

return len([r for r in insertion_records if r.already_existed is False]), len(
access_aware_chunks
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
)

return result


def build_indexing_pipeline(
*,

@@ -1,9 +1,10 @@
import time
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -13,6 +14,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.models import Document
from danswer.redis.redis_object_helper import RedisObjectHelper


@@ -30,6 +32,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))

# documents that should be skipped
self.skip_docs: set[str] = set()

@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@@ -45,14 +50,19 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"

def set_skip_docs(self, skip_docs: set[str]) -> None:
# documents that should be skipped. Note that this class updates
# the list on the fly
self.skip_docs = skip_docs

def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()

async_results = []
@@ -63,7 +73,10 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)

num_docs = 0
for doc in db_session.scalars(stmt).yield_per(1):
doc = cast(Document, doc)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -71,6 +84,12 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
lock.reacquire()
last_lock_time = current_time

num_docs += 1

# check if we should skip the document (typically because it's already syncing)
if doc.id in self.skip_docs:
continue

# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -93,5 +112,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
)

async_results.append(result)
self.skip_docs.add(doc.id)

return len(async_results)
return len(async_results), num_docs

@@ -6,6 +7,7 @@ from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -83,7 +84,7 @@ class RedisConnectorDelete:
self,
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock,
lock: RedisLock,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""

@@ -6,7 +6,7 @@ import redis
from pydantic import BaseModel


class RedisConnectorIndexingFenceData(BaseModel):
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
@@ -71,22 +71,20 @@ class RedisConnectorIndex:
return False

@property
def payload(self) -> RedisConnectorIndexingFenceData | None:
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_str)
)
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))

return payload

def set_fence(
self,
payload: RedisConnectorIndexingFenceData | None,
payload: RedisConnectorIndexPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)

@@ -4,6 +4,7 @@ from uuid import uuid4

import redis
from celery import Celery
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -105,7 +106,7 @@ class RedisConnectorPrune:
documents_to_prune: set[str],
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock | None,
lock: RedisLock | None,
) -> int | None:
last_lock_time = time.monotonic()


@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -50,9 +51,9 @@ class RedisDocumentSet(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()

async_results = []
@@ -84,7 +85,7 @@ class RedisDocumentSet(RedisObjectHelper):

async_results.append(result)

return len(async_results)
return len(async_results), len(async_results)

def reset(self) -> None:
self.redis.delete(self.taskset_key)

@@ -1,9 +1,9 @@
from abc import ABC
from abc import abstractmethod

import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.redis.redis_pool import get_redis_client
@@ -85,7 +85,13 @@ class RedisObjectHelper(ABC):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
pass
) -> tuple[int, int] | None:
"""First element should be the number of actual tasks generated, second should
be the number of docs that were candidates to be synced for the cc pair.

The need for this is when we are syncing stale docs referenced by multiple
connectors. In a single pass across multiple cc pairs, we only want a task
to be created for a particular document id the first time we see it.
The rest can be skipped."""

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -51,15 +52,15 @@ class RedisUserGroup(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()

async_results = []

if not global_version.is_ee_version():
return 0
return 0, 0

try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
@@ -67,7 +68,7 @@ class RedisUserGroup(RedisObjectHelper):
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
return 0, 0

stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
@@ -97,7 +98,7 @@ class RedisUserGroup(RedisObjectHelper):

async_results.append(result)

return len(async_results)
return len(async_results), len(async_results)

def reset(self) -> None:
self.redis.delete(self.taskset_key)

@@ -1,5 +1,4 @@
from typing import Any

from atlassian import Confluence # type: ignore
from sqlalchemy.orm import Session

from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
@@ -19,12 +18,8 @@ def _get_group_members_email_paginated(
confluence_client: OnyxConfluence,
group_name: str,
) -> set[str]:
members: list[dict[str, Any]] = []
for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
members.extend(member_batch)

group_member_emails: set[str] = set()
for member in members:
for member in confluence_client.paginated_group_members_retrieval(group_name):
email = member.get("email")
if not email:
user_name = member.get("username")
@@ -43,19 +38,33 @@ def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
credentials = cc_pair.credential.credential_json
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base = cc_pair.connector.connector_specific_config["wiki_base"]

# test connection with direct client, no retries
confluence_client = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
)
spaces = confluence_client.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(f"No spaces found at {wiki_base}!")

confluence_client = build_confluence_client(
credentials_json=cc_pair.credential.credential_json,
credentials_json=credentials,
is_cloud=is_cloud,
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
wiki_base=wiki_base,
)

# Get all group names
group_names: list[str] = []
for group_batch in confluence_client.paginated_groups_retrieval():
for group in group_batch:
if group_name := group.get("name"):
group_names.append(group_name)
for group in confluence_client.paginated_groups_retrieval():
if group_name := group.get("name"):
group_names.append(group_name)

# For each group name, get all members and create a danswer group
danswer_groups: list[ExternalUserGroup] = []

@@ -29,6 +29,78 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture


# def test_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")

# # create connectors
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.INGESTION_API,
# user_performing_action=admin_user,
# )

# cc_pair_info = CCPairManager.get_single(
# cc_pair_1.id, user_performing_action=admin_user
# )
# assert cc_pair_info
# assert cc_pair_info.creator
# assert str(cc_pair_info.creator) == admin_user.id
# assert cc_pair_info.creator_email == admin_user.email


# TODO(rkuo): will enable this once i have credentials on github
# def test_overlapping_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")

# config = {
# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
# "space": os.environ["CONFLUENCE_TEST_SPACE"],
# "is_cloud": True,
# "page_id": "",
# }

# credential = {
# "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
# }

# # store the time before we create the connector so that we know after
# # when the indexing should have started
# now = datetime.now(timezone.utc)

# # create connector
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )

# CCPairManager.wait_for_indexing(
# cc_pair_1, now, timeout=60, user_performing_action=admin_user
# )

# cc_pair_2 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )

# CCPairManager.wait_for_indexing(
# cc_pair_2, now, timeout=60, user_performing_action=admin_user
# )

# info_1 = CCPairManager.get_single(cc_pair_1.id)
# assert info_1

# info_2 = CCPairManager.get_single(cc_pair_2.id)
# assert info_2

# assert info_1.num_docs_indexed == info_2.num_docs_indexed


def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

@@ -104,6 +104,7 @@ const RenderField: FC<RenderFieldProps> = ({
type={field.type}
label={label}
name={field.name}
isTextArea={true}
/>
)}
</>

@@ -221,8 +221,11 @@ export const connectorConfigs: Record<
},
{
type: "text",
description:
"Enter a comma separated list of the URLs of the shared drives to index. Leave blank to index all shared drives.",
description: (currentCredential) => {
return currentCredential?.credential_json?.google_tokens
? "If you are a non-admin user authenticated using Google Oauth, you will need to specify the URLs for the shared drives you would like to index. Leaving this blank will NOT index any shared drives."
: "Enter a comma separated list of the URLs for the shared drive you would like to index. Leave this blank to index all shared drives.";
},
label: "Shared Drive URLs",
name: "shared_drive_urls",
visibleCondition: (values) => values.include_shared_drives,
@@ -230,14 +233,16 @@ export const connectorConfigs: Record<
},
{
type: "checkbox",
label: (currentCredential) =>
currentCredential?.credential_json?.google_drive_tokens
label: (currentCredential) => {
return currentCredential?.credential_json?.google_tokens
? "Include My Drive?"
: "Include Everyone's My Drive?",
description: (currentCredential) =>
currentCredential?.credential_json?.google_drive_tokens
: "Include Everyone's My Drive?";
},
description: (currentCredential) => {
return currentCredential?.credential_json?.google_tokens
? "This will allow Danswer to index everything in your My Drive."
: "This will allow Danswer to index everything in everyone's My Drives.",
: "This will allow Danswer to index everything in everyone's My Drives.";
},
name: "include_my_drives",
optional: true,
default: true,
@@ -250,7 +255,7 @@ export const connectorConfigs: Record<
name: "my_drive_emails",
visibleCondition: (values, currentCredential) =>
values.include_my_drives &&
!currentCredential?.credential_json?.google_drive_tokens,
!currentCredential?.credential_json?.google_tokens,
optional: true,
},
],
@@ -258,7 +263,7 @@ export const connectorConfigs: Record<
{
type: "text",
description:
"Enter a comma separated list of the URLs of the folders located in Shared Drives to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to the 'Include Shared Drives' and 'Shared Drive URLs' settings, so leave those blank if you only want to index the folders specified here.",
"Enter a comma separated list of the URLs of any folders you would like to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to whatever settings you have selected above, so leave those blank if you only want to index the folders specified here.",
label: "Folder URLs",
name: "shared_folder_urls",
optional: true,