Compare commits

...

18 Commits

Author SHA1 Message Date
rkuo-danswer
4e4c48291a Merge pull request #3207 from danswer-ai/hotfix/v0.13-double-check
Hotfix/v0.13 double check
2024-11-22 00:53:03 -08:00
Richard Kuo
871918738f remove unnecessary backported tests 2024-11-22 00:24:48 -08:00
Richard Kuo
b1780732b9 backport double check of error conditions 2024-11-22 00:20:18 -08:00
rkuo-danswer
e92f48c17d Merge pull request #3193 from danswer-ai/hotfix/v0.13-google-oauth
Merge hotfix/v0.13-google-oauth into release/v0.13
2024-11-21 14:12:21 -08:00
hagen-danswer
8fad5f7e5e Updated google copy and added non admin oauth support (#3120)
* Updated google copy and added non admin oauth support

* backend update

* accounted for oauth

* further removed class variables

* updated sets
2024-11-21 18:28:30 +00:00
rkuo-danswer
310732d10f Merge pull request #3185 from danswer-ai/hotfix/conf
added logging and bugfixing to conf (#3167)
2024-11-20 16:30:37 -08:00
hagen-danswer
322d7cdc90 brought the timeout changes too 2024-11-20 15:58:28 -08:00
hagen-danswer
67d943da11 added logging and bugfixing to conf (#3167)
* standardized escaping of CQL strings

* think i found it

* fix

* should be fixed

* added handling for special linking behavior in confluence

* Update onyx_confluence.py

* Update onyx_confluence.py

---------

Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-11-20 15:53:25 -08:00
rkuo-danswer
9456fef307 Merge pull request #3161 from danswer-ai/hotfix/v0.13-indexing-redux
enhanced logging for indexing and increased indexing timeouts
2024-11-18 19:16:39 -08:00
Richard Kuo (Danswer)
cc3c0800f0 no idea how those files got into the merge 2024-11-18 18:38:29 -08:00
Richard Kuo (Danswer)
e860f15b64 hotfix merge 2024-11-18 18:14:21 -08:00
rkuo-danswer
574ef470a4 Merge pull request #3149 from danswer-ai/hotfix/v0.13-overlapping-connectors
merge overlapping connector hotfix
2024-11-16 22:34:02 -08:00
Richard Kuo
9e391495c2 fix unused stuff for hotfix 2024-11-16 21:11:39 -08:00
Richard Kuo
e26d5430fa merge overlapping connector hotfix 2024-11-16 20:59:00 -08:00
rkuo-danswer
cce0ec2f22 Merge pull request #3141 from danswer-ai/hotfix/v0.13-indexing-concurrency
Merge hotfix/v0.13-indexing-concurrency into release/v0.13
2024-11-15 12:51:41 -08:00
rkuo-danswer
a4f09a62a5 Merge pull request #3142 from danswer-ai/hotfix/v0.13-session-text
Merge hotfix/v0.13-session-text into release/v0.13
2024-11-15 12:51:23 -08:00
rkuo-danswer
fd2428d97f Merge pull request #3131 from danswer-ai/bugfix/session_text
use text()
2024-11-15 20:23:18 +00:00
rkuo-danswer
cfc46812c8 scale indexing sql pool based on concurrency (#3130) 2024-11-15 20:21:43 +00:00
27 changed files with 838 additions and 374 deletions

View File

@@ -59,7 +59,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:

View File

@@ -14,10 +14,16 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
@@ -134,6 +140,23 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
RedisConnectorStop.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Orphaned index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:

View File

@@ -1,12 +1,12 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -87,7 +87,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.

View File

@@ -10,6 +10,8 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -32,6 +34,8 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -44,7 +48,8 @@ from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -61,14 +66,18 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
@@ -76,10 +85,70 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
return False
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
try:
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -90,7 +159,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -100,6 +169,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -118,13 +188,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -180,6 +255,29 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -189,6 +287,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -293,10 +396,11 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -325,7 +429,7 @@ def try_creating_indexing_task(
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
@@ -347,6 +451,8 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -368,13 +474,16 @@ def try_creating_indexing_task(
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(payload)
task_logger.exception(
f"Unexpected exception: "
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -392,7 +501,7 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -400,7 +509,7 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -411,7 +520,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -419,7 +528,7 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -443,7 +552,7 @@ def connector_indexing_proxy_task(
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
f"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -455,7 +564,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -463,6 +572,38 @@ def connector_indexing_proxy_task(
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -516,6 +657,7 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -523,18 +665,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -557,7 +699,7 @@ def connector_indexing_task(
)
break
lock = r.lock(
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -566,7 +708,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None

View File

@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -13,6 +12,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
@@ -45,13 +45,10 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -162,7 +159,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -180,7 +177,12 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
@@ -188,22 +190,21 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
if result is None:
continue
if tasks_generated == 0:
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += tasks_generated
total_tasks_generated += result[0]
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -218,7 +219,7 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -246,12 +247,11 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -260,7 +260,7 @@ def try_generate_document_set_sync_tasks(
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -273,7 +273,7 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -302,12 +302,11 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -316,7 +315,7 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -580,8 +579,8 @@ def monitor_ccpair_indexing_taskset(
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -590,30 +589,46 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
if status_int is None: # inner signal not set ... possible error
result_state = result.state
if (
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
task_logger.warning(msg)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
redis_connector_index.reset()
return
@@ -621,8 +636,8 @@ def monitor_ccpair_indexing_taskset(
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -643,7 +658,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -677,32 +692,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"pruning={n_pruning}"
)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

View File

@@ -433,11 +433,13 @@ def run_indexing_entrypoint(
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -445,10 +447,8 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

View File

@@ -74,7 +74,7 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.

View File

@@ -3,6 +3,8 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from atlassian import Confluence # type: ignore
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
@@ -70,7 +72,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.confluence_client: OnyxConfluence | None = None
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -97,39 +99,59 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = build_confluence_client(
self._confluence_client = build_confluence_client(
credentials_json=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
client_without_retries = Confluence(
api_version="cloud" if self.is_cloud else "latest",
url=self.wiki_base.rstrip("/"),
username=credentials["confluence_username"] if self.is_cloud else None,
password=credentials["confluence_access_token"] if self.is_cloud else None,
token=credentials["confluence_access_token"] if not self.is_cloud else None,
)
spaces = client_without_retries.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(
f"No spaces found at {self.wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comments in self.confluence_client.paginated_cql_page_retrieval(
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
):
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)
return comment_string
@@ -141,28 +163,30 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"]
base_url=self.wiki_base,
content_url=confluence_object["_links"]["webui"],
is_cloud=self.is_cloud,
)
object_text = None
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
self.confluence_client, confluence_object
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
self.confluence_client, confluence_object
confluence_client=self.confluence_client, attachment=confluence_object
)
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
@@ -193,44 +217,39 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
# Fetch pages as Documents
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachments in self.confluence_client.paginated_cql_page_retrieval(
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -255,48 +274,47 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
for pages in self.confluence_client.cql_paginate_all_expansions(
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
):
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
yield doc_metadata_list
doc_metadata_list = []

View File

@@ -20,6 +20,10 @@ F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
pass
@@ -80,7 +84,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
@@ -88,13 +92,16 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
@@ -103,7 +110,6 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
@@ -141,7 +147,7 @@ class OnyxConfluence(Confluence):
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
"""
@@ -153,46 +159,43 @@ class OnyxConfluence(Confluence):
while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
# yield the results individually
yield from next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
def paginated_cql_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
yield from self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
@@ -201,7 +204,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -221,6 +224,44 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object
def paginated_cql_user_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
expand_string = f"&expand={expand}" if expand else ""
yield from self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)

View File

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
import bs4
@@ -71,7 +72,9 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -79,7 +82,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
@@ -100,6 +103,73 @@ def extract_text_from_confluence_html(
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
@@ -153,7 +223,9 @@ def attachment_to_content(
return extracted_text
def build_confluence_document_id(base_url: str, content_url: str) -> str:
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
@@ -164,6 +236,8 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str:
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
@@ -209,6 +283,6 @@ def build_confluence_client(
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_retries=10,
max_backoff_seconds=60,
)

View File

@@ -192,23 +192,33 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
def _get_all_user_emails(self) -> list[str]:
# Start with primary admin email
user_emails = [self.primary_admin_email]
# Only fetch additional users if using service account
if isinstance(self.creds, OAuthCredentials):
return user_emails
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
query = "isAdmin=true" if admins_only else "isAdmin=false"
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
# Get admins first since they're more likely to have access to most files
for is_admin in [True, False]:
query = "isAdmin=true" if is_admin else "isAdmin=false"
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
if email not in user_emails:
user_emails.append(email)
return user_emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
@@ -216,55 +226,48 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
user_email=self.primary_admin_email,
)
all_drive_ids = set()
# We don't want to fail if we're using OAuth because you can
# access your my drive as a non admin user in an org still
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
continue_on_404_or_403=ignore_fetch_failure,
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
return all_drive_ids
def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first becuase they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails
self._all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
if not all_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
"No drives found. This is likely because oauth user "
"is not an admin and cannot view all drive IDs. "
"Continuing with only the shared drive IDs specified in the config."
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
all_drive_ids = set(self._requested_shared_drive_ids)
if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
return all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - no specific emails were requested
# - the current user's email is in the requested emails
# - we are using OAuth (in which case we assume that is the only email we will try)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
or isinstance(self.creds, OAuthCredentials)
):
yield from get_all_files_in_my_drive(
service=drive_service,
@@ -274,7 +277,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
@@ -285,7 +288,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
@@ -302,22 +305,56 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
self._initialize_all_class_variables()
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
filtered_folder_ids = self._requested_folder_ids - all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_drive_ids)
# If including shared drives, use the requested IDs if provided,
# otherwise use all drive IDs
filtered_drive_ids = set()
if self.include_shared_drives:
if self._requested_shared_drive_ids:
# Remove invalid drive IDs from requested IDs
filtered_drive_ids = (
self._requested_shared_drive_ids - invalid_drive_ids
)
else:
filtered_drive_ids = all_drive_ids
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval, email, is_slim, start, end
self._impersonate_user_for_retrieval,
email,
is_slim,
filtered_drive_ids,
filtered_folder_ids,
start,
end,
): email
for email in self._all_org_emails
for email in all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = (
filtered_drive_ids | filtered_folder_ids
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"

View File

@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.warning(f"Error executing request: {e}")
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e

View File

@@ -169,6 +169,7 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
@@ -323,23 +324,23 @@ def upsert_documents(
def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session, connector_id: int, credential_id: int, document_ids: list[str]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
if not document_metadata_batch:
logger.info("`document_metadata_batch` is empty. Skipping.")
if not document_ids:
logger.info("`document_ids` is empty. Skipping.")
return
insert_stmt = insert(DocumentByConnectorCredentialPair).values(
[
model_to_dict(
DocumentByConnectorCredentialPair(
id=document_metadata.document_id,
connector_id=document_metadata.connector_id,
credential_id=document_metadata.credential_id,
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
)
)
for document_metadata in document_metadata_batch
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
@@ -400,17 +401,6 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None:
db_session.commit()
def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
@@ -520,7 +510,7 @@ def prepare_to_modify_documents(
db_session.commit() # ensure that we're not in a transaction
lock_acquired = False
for _ in range(_NUM_LOCK_ATTEMPTS):
for i in range(_NUM_LOCK_ATTEMPTS):
try:
with db_session.begin() as transaction:
lock_acquired = acquire_document_locks(
@@ -531,7 +521,7 @@ def prepare_to_modify_documents(
break
except OperationalError as e:
logger.warning(
f"Failed to acquire locks for documents, retrying. Error: {e}"
f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}"
)
time.sleep(retry_delay)

View File

@@ -312,7 +312,9 @@ async def get_async_session_with_tenant(
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
except Exception:
logger.exception("Error setting search_path.")
@@ -373,7 +375,9 @@ def get_session_with_tenant(
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()

View File

@@ -67,6 +67,13 @@ def create_index_attempt(
return new_attempt.id
def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt:
db_session.delete(index_attempt)
db_session.commit()
def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,

View File

@@ -20,7 +20,8 @@ from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_documents_complete
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.document import upsert_documents
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import Document as DBDocument
@@ -56,13 +57,13 @@ class IndexingPipelineProtocol(Protocol):
...
def upsert_documents_in_db(
def _upsert_documents_in_db(
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
) -> None:
# Metadata here refers to basic document info, not metadata about the actual content
doc_m_batch: list[DocumentMetadata] = []
document_metadata_list: list[DocumentMetadata] = []
for doc in documents:
first_link = next(
(section.link for section in doc.sections if section.link), ""
@@ -77,12 +78,9 @@ def upsert_documents_in_db(
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
from_ingestion_api=doc.from_ingestion_api,
)
doc_m_batch.append(db_doc_metadata)
document_metadata_list.append(db_doc_metadata)
upsert_documents_complete(
db_session=db_session,
document_metadata_batch=doc_m_batch,
)
upsert_documents(db_session, document_metadata_list)
# Insert document content metadata
for doc in documents:
@@ -95,21 +93,25 @@ def upsert_documents_in_db(
document_id=doc.id,
db_session=db_session,
)
else:
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
continue
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
"""Figures out which documents actually need to be updated. If a document is already present
and the `updated_at` hasn't changed, we shouldn't need to do anything with it."""
and the `updated_at` hasn't changed, we shouldn't need to do anything with it.
NB: Still need to associate the document in the DB if multiple connectors are
indexing the same doc."""
id_update_time_map = {
doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
}
@@ -195,9 +197,9 @@ def index_doc_batch_prepare(
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = []
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
@@ -212,43 +214,58 @@ def index_doc_batch_prepare(
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
elif (
document.title is not None and not document.title.strip() and empty_contents
):
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
else:
documents.append(document)
continue
document_ids = [document.id for document in documents]
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
updatable_docs = (
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
if not ignore_time_skip
else documents
)
# No docs to update either because the batch is empty or every doc was already indexed
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
if updatable_docs:
_upsert_documents_in_db(
documents=updatable_docs,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
logger.info(
f"Upserted {len(updatable_docs)} changed docs out of "
f"{len(documents)} total docs into the DB"
)
# for all docs, upsert the document to cc pair relationship
upsert_document_by_connector_credential_pair(
db_session,
index_attempt_metadata.connector_id,
index_attempt_metadata.credential_id,
document_ids,
)
# No docs to process because the batch is empty or every doc was already indexed
if not updatable_docs:
return None
# Create records in the source of truth about these documents,
# does not include doc_updated_at which is also used to indicate a successful update
upsert_documents_in_db(
documents=documents,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
id_to_db_doc_map = {doc.id: doc for doc in db_docs}
return DocumentBatchPrepareContext(
updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map
@@ -269,7 +286,10 @@ def index_doc_batch(
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
memory requirements
Returns a tuple where the first element is the number of new docs and the
second element is the number of chunks."""
no_access = DocumentAccess.build(
user_emails=[],
@@ -312,9 +332,9 @@ def index_doc_batch(
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for immediate and most likely correct sync, but
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync
# always triggers a final metadata sync via the celery queue
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
@@ -351,7 +371,8 @@ def index_doc_batch(
ids_to_new_updated_at = {}
for doc in successful_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the connector source's idea of when the doc was last modified
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
if doc.doc_updated_at is None:
continue
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
@@ -366,10 +387,13 @@ def index_doc_batch(
db_session.commit()
return len([r for r in insertion_records if r.already_existed is False]), len(
access_aware_chunks
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
)
return result
def build_indexing_pipeline(
*,

View File

@@ -1,9 +1,10 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -13,6 +14,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.models import Document
from danswer.redis.redis_object_helper import RedisObjectHelper
@@ -30,6 +32,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
# documents that should be skipped
self.skip_docs: set[str] = set()
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@@ -45,14 +50,19 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def set_skip_docs(self, skip_docs: set[str]) -> None:
# documents that should be skipped. Note that this classes updates
# the list on the fly
self.skip_docs = skip_docs
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
@@ -63,7 +73,10 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
num_docs = 0
for doc in db_session.scalars(stmt).yield_per(1):
doc = cast(Document, doc)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -71,6 +84,12 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
lock.reacquire()
last_lock_time = current_time
num_docs += 1
# check if we should skip the document (typically because it's already syncing)
if doc.id in self.skip_docs:
continue
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -93,5 +112,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
)
async_results.append(result)
self.skip_docs.add(doc.id)
return len(async_results)
return len(async_results), num_docs

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -83,7 +84,7 @@ class RedisConnectorDelete:
self,
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock,
lock: RedisLock,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""

View File

@@ -6,7 +6,7 @@ import redis
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
@@ -71,22 +71,20 @@ class RedisConnectorIndex:
return False
@property
def payload(self) -> RedisConnectorIndexingFenceData | None:
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_str)
)
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
def set_fence(
self,
payload: RedisConnectorIndexingFenceData | None,
payload: RedisConnectorIndexPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -105,7 +106,7 @@ class RedisConnectorPrune:
documents_to_prune: set[str],
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock | None,
lock: RedisLock | None,
) -> int | None:
last_lock_time = time.monotonic()

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -50,9 +51,9 @@ class RedisDocumentSet(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
@@ -84,7 +85,7 @@ class RedisDocumentSet(RedisObjectHelper):
async_results.append(result)
return len(async_results)
return len(async_results), len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)

View File

@@ -1,9 +1,9 @@
from abc import ABC
from abc import abstractmethod
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.redis.redis_pool import get_redis_client
@@ -85,7 +85,13 @@ class RedisObjectHelper(ABC):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
pass
) -> tuple[int, int] | None:
"""First element should be the number of actual tasks generated, second should
be the number of docs that were candidates to be synced for the cc pair.
The need for this is when we are syncing stale docs referenced by multiple
connectors. In a single pass across multiple cc pairs, we only want a task
for be created for a particular document id the first time we see it.
The rest can be skipped."""

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -51,15 +52,15 @@ class RedisUserGroup(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
return 0, 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
@@ -67,7 +68,7 @@ class RedisUserGroup(RedisObjectHelper):
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
return 0, 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
@@ -97,7 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
async_results.append(result)
return len(async_results)
return len(async_results), len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)

View File

@@ -1,5 +1,4 @@
from typing import Any
from atlassian import Confluence # type: ignore
from sqlalchemy.orm import Session
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
@@ -19,12 +18,8 @@ def _get_group_members_email_paginated(
confluence_client: OnyxConfluence,
group_name: str,
) -> set[str]:
members: list[dict[str, Any]] = []
for member_batch in confluence_client.paginated_group_members_retrieval(group_name):
members.extend(member_batch)
group_member_emails: set[str] = set()
for member in members:
for member in confluence_client.paginated_group_members_retrieval(group_name):
email = member.get("email")
if not email:
user_name = member.get("username")
@@ -43,19 +38,33 @@ def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> None:
credentials = cc_pair.credential.credential_json
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base = cc_pair.connector.connector_specific_config["wiki_base"]
# test connection with direct client, no retries
confluence_client = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
)
spaces = confluence_client.get_all_spaces(limit=1)
if not spaces:
raise RuntimeError(f"No spaces found at {wiki_base}!")
confluence_client = build_confluence_client(
credentials_json=cc_pair.credential.credential_json,
credentials_json=credentials,
is_cloud=is_cloud,
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
wiki_base=wiki_base,
)
# Get all group names
group_names: list[str] = []
for group_batch in confluence_client.paginated_groups_retrieval():
for group in group_batch:
if group_name := group.get("name"):
group_names.append(group_name)
for group in confluence_client.paginated_groups_retrieval():
if group_name := group.get("name"):
group_names.append(group_name)
# For each group name, get all members and create a danswer group
danswer_groups: list[ExternalUserGroup] = []

View File

@@ -29,6 +29,78 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
# def test_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")
# # create connectors
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.INGESTION_API,
# user_performing_action=admin_user,
# )
# cc_pair_info = CCPairManager.get_single(
# cc_pair_1.id, user_performing_action=admin_user
# )
# assert cc_pair_info
# assert cc_pair_info.creator
# assert str(cc_pair_info.creator) == admin_user.id
# assert cc_pair_info.creator_email == admin_user.email
# TODO(rkuo): will enable this once i have credentials on github
# def test_overlapping_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")
# config = {
# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
# "space": os.environ["CONFLUENCE_TEST_SPACE"],
# "is_cloud": True,
# "page_id": "",
# }
# credential = {
# "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
# }
# # store the time before we create the connector so that we know after
# # when the indexing should have started
# now = datetime.now(timezone.utc)
# # create connector
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )
# CCPairManager.wait_for_indexing(
# cc_pair_1, now, timeout=60, user_performing_action=admin_user
# )
# cc_pair_2 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )
# CCPairManager.wait_for_indexing(
# cc_pair_2, now, timeout=60, user_performing_action=admin_user
# )
# info_1 = CCPairManager.get_single(cc_pair_1.id)
# assert info_1
# info_2 = CCPairManager.get_single(cc_pair_2.id)
# assert info_2
# assert info_1.num_docs_indexed == info_2.num_docs_indexed
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -104,6 +104,7 @@ const RenderField: FC<RenderFieldProps> = ({
type={field.type}
label={label}
name={field.name}
isTextArea={true}
/>
)}
</>

View File

@@ -221,8 +221,11 @@ export const connectorConfigs: Record<
},
{
type: "text",
description:
"Enter a comma separated list of the URLs of the shared drives to index. Leave blank to index all shared drives.",
description: (currentCredential) => {
return currentCredential?.credential_json?.google_tokens
? "If you are a non-admin user authenticated using Google Oauth, you will need to specify the URLs for the shared drives you would like to index. Leaving this blank will NOT index any shared drives."
: "Enter a comma separated list of the URLs for the shared drive you would like to index. Leave this blank to index all shared drives.";
},
label: "Shared Drive URLs",
name: "shared_drive_urls",
visibleCondition: (values) => values.include_shared_drives,
@@ -230,14 +233,16 @@ export const connectorConfigs: Record<
},
{
type: "checkbox",
label: (currentCredential) =>
currentCredential?.credential_json?.google_drive_tokens
label: (currentCredential) => {
return currentCredential?.credential_json?.google_tokens
? "Include My Drive?"
: "Include Everyone's My Drive?",
description: (currentCredential) =>
currentCredential?.credential_json?.google_drive_tokens
: "Include Everyone's My Drive?";
},
description: (currentCredential) => {
return currentCredential?.credential_json?.google_tokens
? "This will allow Danswer to index everything in your My Drive."
: "This will allow Danswer to index everything in everyone's My Drives.",
: "This will allow Danswer to index everything in everyone's My Drives.";
},
name: "include_my_drives",
optional: true,
default: true,
@@ -250,7 +255,7 @@ export const connectorConfigs: Record<
name: "my_drive_emails",
visibleCondition: (values, currentCredential) =>
values.include_my_drives &&
!currentCredential?.credential_json?.google_drive_tokens,
!currentCredential?.credential_json?.google_tokens,
optional: true,
},
],
@@ -258,7 +263,7 @@ export const connectorConfigs: Record<
{
type: "text",
description:
"Enter a comma separated list of the URLs of the folders located in Shared Drives to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to the 'Include Shared Drives' and 'Shared Drive URLs' settings, so leave those blank if you only want to index the folders specified here.",
"Enter a comma separated list of the URLs of any folders you would like to index. The files located in these folders (and all subfolders) will be indexed. Note: This will be in addition to whatever settings you have selected above, so leave those blank if you only want to index the folders specified here.",
label: "Folder URLs",
name: "shared_folder_urls",
optional: true,