mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
13 Commits
dump-scrip
...
v0.13.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
310732d10f | ||
|
|
322d7cdc90 | ||
|
|
67d943da11 | ||
|
|
9456fef307 | ||
|
|
cc3c0800f0 | ||
|
|
e860f15b64 | ||
|
|
574ef470a4 | ||
|
|
9e391495c2 | ||
|
|
e26d5430fa | ||
|
|
cce0ec2f22 | ||
|
|
a4f09a62a5 | ||
|
|
fd2428d97f | ||
|
|
cfc46812c8 |
@@ -59,7 +59,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -14,10 +14,14 @@ from celery.signals import worker_shutdown
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
@@ -134,6 +138,23 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
RedisConnectorStop.reset_all(r)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
if not attempt:
|
||||
continue
|
||||
|
||||
failure_reason = (
|
||||
f"Orphaned index attempt found on startup: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
logger.warning(failure_reason)
|
||||
mark_attempt_failed(attempt.id, db_session, failure_reason)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -87,7 +87,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
|
||||
@@ -3,13 +3,14 @@ from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -44,7 +45,7 @@ from danswer.db.swap_index import check_index_swap
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -61,14 +62,18 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
self,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: redis.lock.Lock,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_lock: redis.lock.Lock = redis_lock
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_client.exists(self.stop_key):
|
||||
@@ -76,7 +81,19 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
|
||||
return False
|
||||
|
||||
def progress(self, amount: int) -> None:
|
||||
self.redis_lock.reacquire()
|
||||
try:
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"RunIndexingCallback - lock.reacquire exceptioned. "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
|
||||
|
||||
@@ -325,7 +342,7 @@ def try_creating_indexing_task(
|
||||
redis_connector_index.generator_clear()
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorIndexingFenceData(
|
||||
payload = RedisConnectorIndexPayload(
|
||||
index_attempt_id=None,
|
||||
started=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
@@ -368,7 +385,7 @@ def try_creating_indexing_task(
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
redis_connector_index.set_fence(payload)
|
||||
redis_connector_index.set_fence(None)
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
|
||||
@@ -13,6 +13,7 @@ from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
from tenacity import RetryError
|
||||
|
||||
@@ -162,7 +163,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
@@ -180,7 +181,12 @@ def try_generate_stale_document_sync_tasks(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
|
||||
task_logger.info(
|
||||
"RedisConnector.generate_tasks starting by cc_pair. "
|
||||
"Documents spanning multiple cc_pairs will only be synced once."
|
||||
)
|
||||
|
||||
docs_to_skip: set[str] = set()
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
@@ -188,22 +194,21 @@ def try_generate_stale_document_sync_tasks(
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
|
||||
tasks_generated = rc.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
rc.set_skip_docs(docs_to_skip)
|
||||
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
|
||||
if tasks_generated is None:
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
if tasks_generated == 0:
|
||||
if result[1] == 0:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for single cc_pair. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
|
||||
)
|
||||
|
||||
total_tasks_generated += tasks_generated
|
||||
total_tasks_generated += result[0]
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
@@ -218,7 +223,7 @@ def try_generate_document_set_sync_tasks(
|
||||
document_set_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
@@ -246,12 +251,11 @@ def try_generate_document_set_sync_tasks(
|
||||
)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
tasks_generated = rds.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
tasks_generated = result[0]
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
@@ -260,7 +264,7 @@ def try_generate_document_set_sync_tasks(
|
||||
|
||||
task_logger.info(
|
||||
f"RedisDocumentSet.generate_tasks finished. "
|
||||
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
|
||||
f"document_set={document_set.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
@@ -273,7 +277,7 @@ def try_generate_user_group_sync_tasks(
|
||||
usergroup_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
@@ -302,12 +306,11 @@ def try_generate_user_group_sync_tasks(
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(
|
||||
celery_app, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
if tasks_generated is None:
|
||||
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
tasks_generated = result[0]
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
@@ -316,7 +319,7 @@ def try_generate_user_group_sync_tasks(
|
||||
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
@@ -580,8 +583,8 @@ def monitor_ccpair_indexing_taskset(
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"Connector indexing progress: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
@@ -602,8 +605,8 @@ def monitor_ccpair_indexing_taskset(
|
||||
# if it isn't, then the worker crashed
|
||||
task_logger.info(
|
||||
f"Connector indexing aborted: "
|
||||
f"cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
|
||||
@@ -621,8 +624,8 @@ def monitor_ccpair_indexing_taskset(
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
|
||||
f"search_settings_id={search_settings_id} "
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
@@ -630,6 +633,37 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
|
||||
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
|
||||
want to clean them up.
|
||||
|
||||
Unfenced = attempt not in terminal state and fence does not exist.
|
||||
"""
|
||||
unfenced_attempts: list[int] = []
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
)
|
||||
if r.exists(fence_key):
|
||||
continue
|
||||
|
||||
unfenced_attempts.append(attempt.id)
|
||||
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
@@ -643,7 +677,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: redis.lock.Lock = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -677,31 +711,24 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
f"pruning={n_pruning}"
|
||||
)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
# check the db for any outstanding index attempts
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempts: list[IndexAttempt] = []
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||
)
|
||||
attempts.extend(
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
if not attempt:
|
||||
continue
|
||||
|
||||
for a in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
failure_reason = (
|
||||
f"Unfenced index attempt found in DB: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
)
|
||||
task_logger.warning(failure_reason)
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
if not r.exists(fence_key):
|
||||
failure_reason = (
|
||||
f"Unknown index attempt. Might be left over from a process restart: "
|
||||
f"index_attempt={a.id} "
|
||||
f"cc_pair={a.connector_credential_pair_id} "
|
||||
f"search_settings={a.search_settings_id}"
|
||||
)
|
||||
task_logger.warning(failure_reason)
|
||||
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
|
||||
@@ -433,11 +433,13 @@ def run_indexing_entrypoint(
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing starting{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
@@ -445,10 +447,8 @@ def run_indexing_entrypoint(
|
||||
_run_indexing(db_session, attempt, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished for tenant {tenant_id}: "
|
||||
if tenant_id is not None
|
||||
else ""
|
||||
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"Indexing finished{tenant_str}: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
|
||||
@@ -3,6 +3,8 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian import Confluence # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -70,7 +72,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.confluence_client: OnyxConfluence | None = None
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
@@ -97,39 +99,59 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_label_filter = ""
|
||||
if labels_to_skip:
|
||||
labels_to_skip = list(set(labels_to_skip))
|
||||
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
|
||||
comma_separated_labels = ",".join(
|
||||
f"'{quote(label)}'" for label in labels_to_skip
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self.confluence_client = build_confluence_client(
|
||||
self._confluence_client = build_confluence_client(
|
||||
credentials_json=credentials,
|
||||
is_cloud=self.is_cloud,
|
||||
wiki_base=self.wiki_base,
|
||||
)
|
||||
|
||||
client_without_retries = Confluence(
|
||||
api_version="cloud" if self.is_cloud else "latest",
|
||||
url=self.wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if self.is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if self.is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not self.is_cloud else None,
|
||||
)
|
||||
spaces = client_without_retries.get_all_spaces(limit=1)
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {self.wiki_base}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
comment_string = ""
|
||||
|
||||
comment_cql = f"type=comment and container='{page_id}'"
|
||||
comment_cql += self.cql_label_filter
|
||||
|
||||
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
|
||||
for comments in self.confluence_client.paginated_cql_page_retrieval(
|
||||
for comment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=comment_cql,
|
||||
expand=expand,
|
||||
):
|
||||
for comment in comments:
|
||||
comment_string += "\nComment:\n"
|
||||
comment_string += extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=comment,
|
||||
)
|
||||
comment_string += "\nComment:\n"
|
||||
comment_string += extract_text_from_confluence_html(
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=comment,
|
||||
fetched_titles=set(),
|
||||
)
|
||||
|
||||
return comment_string
|
||||
|
||||
@@ -141,28 +163,30 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
"""
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"]
|
||||
base_url=self.wiki_base,
|
||||
content_url=confluence_object["_links"]["webui"],
|
||||
is_cloud=self.is_cloud,
|
||||
)
|
||||
|
||||
object_text = None
|
||||
# Extract text from page
|
||||
if confluence_object["type"] == "page":
|
||||
object_text = extract_text_from_confluence_html(
|
||||
self.confluence_client, confluence_object
|
||||
confluence_client=self.confluence_client,
|
||||
confluence_object=confluence_object,
|
||||
fetched_titles={confluence_object.get("title", "")},
|
||||
)
|
||||
# Add comments to text
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
self.confluence_client, confluence_object
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
# This only happens for attachments that are not parseable
|
||||
return None
|
||||
|
||||
# Get space name
|
||||
@@ -193,44 +217,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
|
||||
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
# Fetch pages as Documents
|
||||
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
for page in page_batch:
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachments in self.confluence_client.paginated_cql_page_retrieval(
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_cql,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
for attachment in attachments:
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
@@ -255,48 +274,47 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter
|
||||
for pages in self.confluence_client.cql_paginate_all_expansions(
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
):
|
||||
for page in pages:
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
perm_sync_data = {
|
||||
"restrictions": page.get("restrictions", {}),
|
||||
"space_key": page.get("space", {}).get("key"),
|
||||
}
|
||||
# If the page has restrictions, add them to the perm_sync_data
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
perm_sync_data = {
|
||||
"restrictions": page.get("restrictions", {}),
|
||||
"space_key": page.get("space", {}).get("key"),
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base,
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
expand=restrictions_expand,
|
||||
):
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
self.wiki_base,
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
for attachments in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
expand=restrictions_expand,
|
||||
):
|
||||
for attachment in attachments:
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"]
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
|
||||
@@ -20,6 +20,10 @@ F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@@ -80,7 +84,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 3600
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
@@ -88,13 +92,16 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
@@ -103,7 +110,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
@@ -141,7 +147,7 @@ class OnyxConfluence(Confluence):
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
"""
|
||||
@@ -153,46 +159,43 @@ class OnyxConfluence(Confluence):
|
||||
|
||||
while url_suffix:
|
||||
try:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
next_response = self.get(url_suffix)
|
||||
except Exception as e:
|
||||
logger.exception("Error in danswer_cql: \n")
|
||||
raise e
|
||||
yield next_response.get("results", [])
|
||||
logger.warning(f"Error in confluence call to {url_suffix}")
|
||||
|
||||
# If the problematic expansion is in the url, replace it
|
||||
# with the replacement expansion and try again
|
||||
# If that fails, raise the error
|
||||
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
raise e
|
||||
logger.warning(
|
||||
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
|
||||
" and trying again."
|
||||
)
|
||||
url_suffix = url_suffix.replace(
|
||||
_PROBLEMATIC_EXPANSIONS,
|
||||
_REPLACEMENT_EXPANSIONS,
|
||||
)
|
||||
continue
|
||||
|
||||
# yield the results individually
|
||||
yield from next_response.get("results", [])
|
||||
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
return self._paginate_url("rest/api/group", limit)
|
||||
|
||||
def paginated_group_members_retrieval(
|
||||
self,
|
||||
group_name: str,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
group_name = quote(group_name)
|
||||
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
def paginated_cql_user_retrieval(
|
||||
def paginated_cql_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
The content/search endpoint can be used to fetch pages, attachments, and comments.
|
||||
"""
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
f"rest/api/search/user?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def paginated_cql_page_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
return self._paginate_url(
|
||||
yield from self._paginate_url(
|
||||
f"rest/api/content/search?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
@@ -201,7 +204,7 @@ class OnyxConfluence(Confluence):
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[list[dict[str, Any]]]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This function will paginate through the top level query first, then
|
||||
paginate through all of the expansions.
|
||||
@@ -221,6 +224,44 @@ class OnyxConfluence(Confluence):
|
||||
for item in data:
|
||||
_traverse_and_update(item)
|
||||
|
||||
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
|
||||
_traverse_and_update(results)
|
||||
yield results
|
||||
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
|
||||
_traverse_and_update(confluence_object)
|
||||
yield confluence_object
|
||||
|
||||
def paginated_cql_user_retrieval(
|
||||
self,
|
||||
cql: str,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
The search/user endpoint can be used to fetch users.
|
||||
It's a seperate endpoint from the content/search endpoint used only for users.
|
||||
Otherwise it's very similar to the content/search endpoint.
|
||||
"""
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
yield from self._paginate_url(
|
||||
f"rest/api/search/user?cql={cql}{expand_string}", limit
|
||||
)
|
||||
|
||||
def paginated_groups_retrieval(
|
||||
self,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This is not an SQL like query.
|
||||
It's a confluence specific endpoint that can be used to fetch groups.
|
||||
"""
|
||||
yield from self._paginate_url("rest/api/group", limit)
|
||||
|
||||
def paginated_group_members_retrieval(
|
||||
self,
|
||||
group_name: str,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This is not an SQL like query.
|
||||
It's a confluence specific endpoint that can be used to fetch the members of a group.
|
||||
"""
|
||||
group_name = quote(group_name)
|
||||
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
@@ -2,6 +2,7 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
|
||||
@@ -71,7 +72,9 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
@@ -79,7 +82,7 @@ def extract_text_from_confluence_html(
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
@@ -100,6 +103,73 @@ def extract_text_from_confluence_html(
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -153,7 +223,9 @@ def attachment_to_content(
|
||||
return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
@@ -164,6 +236,8 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
if is_cloud and not base_url.endswith("/wiki"):
|
||||
base_url += "/wiki"
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
@@ -209,6 +283,6 @@ def build_confluence_client(
|
||||
password=credentials_json["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials_json["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=60,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
)
|
||||
|
||||
@@ -169,6 +169,7 @@ def get_document_connector_counts(
|
||||
def get_document_counts_for_cc_pairs(
|
||||
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
@@ -323,23 +324,23 @@ def upsert_documents(
|
||||
|
||||
|
||||
def upsert_document_by_connector_credential_pair(
|
||||
db_session: Session, document_metadata_batch: list[DocumentMetadata]
|
||||
db_session: Session, connector_id: int, credential_id: int, document_ids: list[str]
|
||||
) -> None:
|
||||
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
|
||||
if not document_metadata_batch:
|
||||
logger.info("`document_metadata_batch` is empty. Skipping.")
|
||||
if not document_ids:
|
||||
logger.info("`document_ids` is empty. Skipping.")
|
||||
return
|
||||
|
||||
insert_stmt = insert(DocumentByConnectorCredentialPair).values(
|
||||
[
|
||||
model_to_dict(
|
||||
DocumentByConnectorCredentialPair(
|
||||
id=document_metadata.document_id,
|
||||
connector_id=document_metadata.connector_id,
|
||||
credential_id=document_metadata.credential_id,
|
||||
id=doc_id,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
)
|
||||
for document_metadata in document_metadata_batch
|
||||
for doc_id in document_ids
|
||||
]
|
||||
)
|
||||
# for now, there are no columns to update. If more metadata is added, then this
|
||||
@@ -400,17 +401,6 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_documents_complete(
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
) -> None:
|
||||
upsert_documents(db_session, document_metadata_batch)
|
||||
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
|
||||
logger.info(
|
||||
f"Upserted {len(document_metadata_batch)} document store entries into DB"
|
||||
)
|
||||
|
||||
|
||||
def delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
@@ -520,7 +510,7 @@ def prepare_to_modify_documents(
|
||||
db_session.commit() # ensure that we're not in a transaction
|
||||
|
||||
lock_acquired = False
|
||||
for _ in range(_NUM_LOCK_ATTEMPTS):
|
||||
for i in range(_NUM_LOCK_ATTEMPTS):
|
||||
try:
|
||||
with db_session.begin() as transaction:
|
||||
lock_acquired = acquire_document_locks(
|
||||
@@ -531,7 +521,7 @@ def prepare_to_modify_documents(
|
||||
break
|
||||
except OperationalError as e:
|
||||
logger.warning(
|
||||
f"Failed to acquire locks for documents, retrying. Error: {e}"
|
||||
f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}"
|
||||
)
|
||||
|
||||
time.sleep(retry_delay)
|
||||
|
||||
@@ -312,7 +312,9 @@ async def get_async_session_with_tenant(
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
@@ -373,7 +375,9 @@ def get_session_with_tenant(
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -20,7 +20,8 @@ from danswer.db.document import get_documents_by_ids
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document import update_docs_last_modified__no_commit
|
||||
from danswer.db.document import update_docs_updated_at__no_commit
|
||||
from danswer.db.document import upsert_documents_complete
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.document import upsert_documents
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.index_attempt import create_index_attempt_error
|
||||
from danswer.db.models import Document as DBDocument
|
||||
@@ -56,13 +57,13 @@ class IndexingPipelineProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
def upsert_documents_in_db(
|
||||
def _upsert_documents_in_db(
|
||||
documents: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Metadata here refers to basic document info, not metadata about the actual content
|
||||
doc_m_batch: list[DocumentMetadata] = []
|
||||
document_metadata_list: list[DocumentMetadata] = []
|
||||
for doc in documents:
|
||||
first_link = next(
|
||||
(section.link for section in doc.sections if section.link), ""
|
||||
@@ -77,12 +78,9 @@ def upsert_documents_in_db(
|
||||
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
|
||||
from_ingestion_api=doc.from_ingestion_api,
|
||||
)
|
||||
doc_m_batch.append(db_doc_metadata)
|
||||
document_metadata_list.append(db_doc_metadata)
|
||||
|
||||
upsert_documents_complete(
|
||||
db_session=db_session,
|
||||
document_metadata_batch=doc_m_batch,
|
||||
)
|
||||
upsert_documents(db_session, document_metadata_list)
|
||||
|
||||
# Insert document content metadata
|
||||
for doc in documents:
|
||||
@@ -95,21 +93,25 @@ def upsert_documents_in_db(
|
||||
document_id=doc.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
create_or_add_document_tag(
|
||||
tag_key=k,
|
||||
tag_value=v,
|
||||
source=doc.source,
|
||||
document_id=doc.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
continue
|
||||
|
||||
create_or_add_document_tag(
|
||||
tag_key=k,
|
||||
tag_value=v,
|
||||
source=doc.source,
|
||||
document_id=doc.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def get_doc_ids_to_update(
|
||||
documents: list[Document], db_docs: list[DBDocument]
|
||||
) -> list[Document]:
|
||||
"""Figures out which documents actually need to be updated. If a document is already present
|
||||
and the `updated_at` hasn't changed, we shouldn't need to do anything with it."""
|
||||
and the `updated_at` hasn't changed, we shouldn't need to do anything with it.
|
||||
|
||||
NB: Still need to associate the document in the DB if multiple connectors are
|
||||
indexing the same doc."""
|
||||
id_update_time_map = {
|
||||
doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
|
||||
}
|
||||
@@ -195,9 +197,9 @@ def index_doc_batch_prepare(
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
) -> DocumentBatchPrepareContext | None:
|
||||
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
|
||||
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
|
||||
This preceeds indexing it into the actual document index."""
|
||||
documents = []
|
||||
documents: list[Document] = []
|
||||
for document in document_batch:
|
||||
empty_contents = not any(section.text.strip() for section in document.sections)
|
||||
if (
|
||||
@@ -212,43 +214,58 @@ def index_doc_batch_prepare(
|
||||
logger.warning(
|
||||
f"Skipping document with ID {document.id} as it has neither title nor content."
|
||||
)
|
||||
elif (
|
||||
document.title is not None and not document.title.strip() and empty_contents
|
||||
):
|
||||
continue
|
||||
|
||||
if document.title is not None and not document.title.strip() and empty_contents:
|
||||
# The title is explicitly empty ("" and not None) and the document is empty
|
||||
# so when building the chunk text representation, it will be empty and unuseable
|
||||
logger.warning(
|
||||
f"Skipping document with ID {document.id} as the chunks will be empty."
|
||||
)
|
||||
else:
|
||||
documents.append(document)
|
||||
continue
|
||||
|
||||
document_ids = [document.id for document in documents]
|
||||
documents.append(document)
|
||||
|
||||
# Create a trimmed list of docs that don't have a newer updated at
|
||||
# Shortcuts the time-consuming flow on connector index retries
|
||||
document_ids: list[str] = [document.id for document in documents]
|
||||
db_docs: list[DBDocument] = get_documents_by_ids(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
# Skip indexing docs that don't have a newer updated at
|
||||
# Shortcuts the time-consuming flow on connector index retries
|
||||
updatable_docs = (
|
||||
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
|
||||
if not ignore_time_skip
|
||||
else documents
|
||||
)
|
||||
|
||||
# No docs to update either because the batch is empty or every doc was already indexed
|
||||
# for all updatable docs, upsert into the DB
|
||||
# Does not include doc_updated_at which is also used to indicate a successful update
|
||||
if updatable_docs:
|
||||
_upsert_documents_in_db(
|
||||
documents=updatable_docs,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Upserted {len(updatable_docs)} changed docs out of "
|
||||
f"{len(documents)} total docs into the DB"
|
||||
)
|
||||
|
||||
# for all docs, upsert the document to cc pair relationship
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session,
|
||||
index_attempt_metadata.connector_id,
|
||||
index_attempt_metadata.credential_id,
|
||||
document_ids,
|
||||
)
|
||||
|
||||
# No docs to process because the batch is empty or every doc was already indexed
|
||||
if not updatable_docs:
|
||||
return None
|
||||
|
||||
# Create records in the source of truth about these documents,
|
||||
# does not include doc_updated_at which is also used to indicate a successful update
|
||||
upsert_documents_in_db(
|
||||
documents=documents,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
id_to_db_doc_map = {doc.id: doc for doc in db_docs}
|
||||
return DocumentBatchPrepareContext(
|
||||
updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map
|
||||
@@ -269,7 +286,10 @@ def index_doc_batch(
|
||||
) -> tuple[int, int]:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
Note that the documents should already be batched at this point so that it does not inflate the
|
||||
memory requirements"""
|
||||
memory requirements
|
||||
|
||||
Returns a tuple where the first element is the number of new docs and the
|
||||
second element is the number of chunks."""
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -312,9 +332,9 @@ def index_doc_batch(
|
||||
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for immediate and most likely correct sync, but
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
# to resolve this, an update of the last modified field at the end of this loop
|
||||
# always triggers a final metadata sync
|
||||
# always triggers a final metadata sync via the celery queue
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
@@ -351,7 +371,8 @@ def index_doc_batch(
|
||||
ids_to_new_updated_at = {}
|
||||
for doc in successful_docs:
|
||||
last_modified_ids.append(doc.id)
|
||||
# doc_updated_at is the connector source's idea of when the doc was last modified
|
||||
# doc_updated_at is the source's idea (on the other end of the connector)
|
||||
# of when the doc was last modified
|
||||
if doc.doc_updated_at is None:
|
||||
continue
|
||||
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
|
||||
@@ -366,10 +387,13 @@ def index_doc_batch(
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return len([r for r in insertion_records if r.already_existed is False]), len(
|
||||
access_aware_chunks
|
||||
result = (
|
||||
len([r for r in insertion_records if r.already_existed is False]),
|
||||
len(access_aware_chunks),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def build_indexing_pipeline(
|
||||
*,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -13,6 +14,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
from danswer.db.models import Document
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
|
||||
@@ -30,6 +32,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
# documents that should be skipped
|
||||
self.skip_docs: set[str] = set()
|
||||
|
||||
@classmethod
|
||||
def get_fence_key(cls) -> str:
|
||||
return RedisConnectorCredentialPair.FENCE_PREFIX
|
||||
@@ -45,14 +50,19 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
# example: connector_taskset
|
||||
return f"{self.TASKSET_PREFIX}"
|
||||
|
||||
def set_skip_docs(self, skip_docs: set[str]) -> None:
|
||||
# documents that should be skipped. Note that this classes updates
|
||||
# the list on the fly
|
||||
self.skip_docs = skip_docs
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
) -> tuple[int, int] | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
@@ -63,7 +73,10 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
|
||||
num_docs = 0
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
doc = cast(Document, doc)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
@@ -71,6 +84,12 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
num_docs += 1
|
||||
|
||||
# check if we should skip the document (typically because it's already syncing)
|
||||
if doc.id in self.skip_docs:
|
||||
continue
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
@@ -93,5 +112,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
self.skip_docs.add(doc.id)
|
||||
|
||||
return len(async_results)
|
||||
return len(async_results), num_docs
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
import redis
|
||||
from celery import Celery
|
||||
from pydantic import BaseModel
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -83,7 +84,7 @@ class RedisConnectorDelete:
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
lock: redis.lock.Lock,
|
||||
lock: RedisLock,
|
||||
) -> int | None:
|
||||
"""Returns None if the cc_pair doesn't exist.
|
||||
Otherwise, returns an int with the number of generated tasks."""
|
||||
|
||||
@@ -6,7 +6,7 @@ import redis
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RedisConnectorIndexingFenceData(BaseModel):
|
||||
class RedisConnectorIndexPayload(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
@@ -71,22 +71,20 @@ class RedisConnectorIndex:
|
||||
return False
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorIndexingFenceData | None:
|
||||
def payload(self) -> RedisConnectorIndexPayload | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorIndexingFenceData.model_validate_json(
|
||||
cast(str, fence_str)
|
||||
)
|
||||
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
|
||||
|
||||
return payload
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
payload: RedisConnectorIndexingFenceData | None,
|
||||
payload: RedisConnectorIndexPayload | None,
|
||||
) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -105,7 +106,7 @@ class RedisConnectorPrune:
|
||||
documents_to_prune: set[str],
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
lock: redis.lock.Lock | None,
|
||||
lock: RedisLock | None,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -50,9 +51,9 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
) -> tuple[int, int] | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
@@ -84,7 +85,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
return len(async_results), len(async_results)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
@@ -85,7 +85,13 @@ class RedisObjectHelper(ABC):
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
) -> tuple[int, int] | None:
|
||||
"""First element should be the number of actual tasks generated, second should
|
||||
be the number of docs that were candidates to be synced for the cc pair.
|
||||
|
||||
The need for this is when we are syncing stale docs referenced by multiple
|
||||
connectors. In a single pass across multiple cc pairs, we only want a task
|
||||
for be created for a particular document id the first time we see it.
|
||||
The rest can be skipped."""
|
||||
|
||||
@@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import redis
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -51,15 +52,15 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
) -> tuple[int, int] | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
|
||||
if not global_version.is_ee_version():
|
||||
return 0
|
||||
return 0, 0
|
||||
|
||||
try:
|
||||
construct_document_select_by_usergroup = fetch_versioned_implementation(
|
||||
@@ -67,7 +68,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
"construct_document_select_by_usergroup",
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
return 0
|
||||
return 0, 0
|
||||
|
||||
stmt = construct_document_select_by_usergroup(int(self._id))
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
@@ -97,7 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
return len(async_results), len(async_results)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from typing import Any
|
||||
|
||||
from atlassian import Confluence # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
@@ -19,12 +18,8 @@ def _get_group_members_email_paginated(
|
||||
confluence_client: OnyxConfluence,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
members: list[dict[str, Any]] = []
|
||||
for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
|
||||
members.extend(member_batch)
|
||||
|
||||
group_member_emails: set[str] = set()
|
||||
for member in members:
|
||||
for member in confluence_client.paginated_group_members_retrieval(group_name):
|
||||
email = member.get("email")
|
||||
if not email:
|
||||
user_name = member.get("username")
|
||||
@@ -43,19 +38,33 @@ def confluence_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> None:
|
||||
credentials = cc_pair.credential.credential_json
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
|
||||
# test connection with direct client, no retries
|
||||
confluence_client = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
)
|
||||
spaces = confluence_client.get_all_spaces(limit=1)
|
||||
if not spaces:
|
||||
raise RuntimeError(f"No spaces found at {wiki_base}!")
|
||||
|
||||
confluence_client = build_confluence_client(
|
||||
credentials_json=cc_pair.credential.credential_json,
|
||||
credentials_json=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
|
||||
# Get all group names
|
||||
group_names: list[str] = []
|
||||
for group_batch in confluence_client.paginated_groups_retrieval():
|
||||
for group in group_batch:
|
||||
if group_name := group.get("name"):
|
||||
group_names.append(group_name)
|
||||
for group in confluence_client.paginated_groups_retrieval():
|
||||
if group_name := group.get("name"):
|
||||
group_names.append(group_name)
|
||||
|
||||
# For each group name, get all members and create a danswer group
|
||||
danswer_groups: list[ExternalUserGroup] = []
|
||||
|
||||
@@ -29,6 +29,78 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
# def test_connector_creation(reset: None) -> None:
|
||||
# # Creating an admin user (first user created is automatically an admin)
|
||||
# admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# # create connectors
|
||||
# cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
# source=DocumentSource.INGESTION_API,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
|
||||
# cc_pair_info = CCPairManager.get_single(
|
||||
# cc_pair_1.id, user_performing_action=admin_user
|
||||
# )
|
||||
# assert cc_pair_info
|
||||
# assert cc_pair_info.creator
|
||||
# assert str(cc_pair_info.creator) == admin_user.id
|
||||
# assert cc_pair_info.creator_email == admin_user.email
|
||||
|
||||
|
||||
# TODO(rkuo): will enable this once i have credentials on github
|
||||
# def test_overlapping_connector_creation(reset: None) -> None:
|
||||
# # Creating an admin user (first user created is automatically an admin)
|
||||
# admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# config = {
|
||||
# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
|
||||
# "space": os.environ["CONFLUENCE_TEST_SPACE"],
|
||||
# "is_cloud": True,
|
||||
# "page_id": "",
|
||||
# }
|
||||
|
||||
# credential = {
|
||||
# "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
# }
|
||||
|
||||
# # store the time before we create the connector so that we know after
|
||||
# # when the indexing should have started
|
||||
# now = datetime.now(timezone.utc)
|
||||
|
||||
# # create connector
|
||||
# cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# connector_specific_config=config,
|
||||
# credential_json=credential,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
|
||||
# CCPairManager.wait_for_indexing(
|
||||
# cc_pair_1, now, timeout=60, user_performing_action=admin_user
|
||||
# )
|
||||
|
||||
# cc_pair_2 = CCPairManager.create_from_scratch(
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# connector_specific_config=config,
|
||||
# credential_json=credential,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
|
||||
# CCPairManager.wait_for_indexing(
|
||||
# cc_pair_2, now, timeout=60, user_performing_action=admin_user
|
||||
# )
|
||||
|
||||
# info_1 = CCPairManager.get_single(cc_pair_1.id)
|
||||
# assert info_1
|
||||
|
||||
# info_2 = CCPairManager.get_single(cc_pair_2.id)
|
||||
# assert info_2
|
||||
|
||||
# assert info_1.num_docs_indexed == info_2.num_docs_indexed
|
||||
|
||||
|
||||
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
Reference in New Issue
Block a user