Compare commits

..

4 Commits

Author SHA1 Message Date
pablodanswer
8b9e1a07d5 typing 2024-11-11 09:26:46 -08:00
pablodanswer
b6301ffcb9 spacing 2024-11-11 09:05:01 -08:00
pablodanswer
490ce0db18 cleaner approach 2024-11-11 09:03:49 -08:00
pablodanswer
b2ca13eaae treat async values differently 2024-11-11 08:59:16 -08:00
152 changed files with 12613 additions and 6898 deletions

View File

@@ -197,8 +197,7 @@ jobs:
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
/app/tests/integration/tests
continue-on-error: true
id: run_tests

View File

@@ -23,6 +23,21 @@ jobs:
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
@@ -37,22 +52,6 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -203,7 +203,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion",
],
"presentation": {
"group": "2",
@@ -232,7 +232,7 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
"connector_pruning",
],
"presentation": {
"group": "2",

View File

@@ -1,68 +0,0 @@
"""default chosen assistants to none
Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Replace ``user.chosen_assistants`` with a nullable JSONB column.

    Rows whose current value equals '[-2, -1, 0]' are rewritten to NULL;
    every other value is copied across unchanged.
    """
    table_name = "user"
    staging_column = "chosen_assistants_new"

    # Stage the converted data in a new nullable column.
    op.add_column(
        table_name,
        sa.Column(staging_column, postgresql.JSONB(), nullable=True),
    )
    op.execute(
        """
        UPDATE "user"
        SET chosen_assistants_new =
        CASE
        WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
        ELSE chosen_assistants
        END
        """
    )

    # Swap the staged column in under the original name.
    op.drop_column(table_name, "chosen_assistants")
    op.alter_column(table_name, staging_column, new_column_name="chosen_assistants")
def downgrade() -> None:
    """Restore ``user.chosen_assistants`` as a non-nullable JSONB column.

    NULL values are rewritten to '[-2, -1, 0]' (which is also the restored
    column's server default); every other value is copied across unchanged.
    """
    table_name = "user"
    staging_column = "chosen_assistants_old"

    # Stage the restored data in a new non-nullable column.
    legacy_column = sa.Column(
        staging_column,
        postgresql.JSONB(),
        nullable=False,
        server_default="[-2, -1, 0]",
    )
    op.add_column(table_name, legacy_column)
    op.execute(
        """
        UPDATE "user"
        SET chosen_assistants_old =
        CASE
        WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
        ELSE chosen_assistants
        END
        """
    )

    # Swap the staged column in under the original name.
    op.drop_column(table_name, "chosen_assistants")
    op.alter_column(table_name, staging_column, new_column_name="chosen_assistants")

View File

@@ -1,30 +0,0 @@
"""add-group-sync-time
Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable, timezone-aware ``last_time_external_group_sync`` column."""
    sync_time_column = sa.Column(
        "last_time_external_group_sync",
        sa.DateTime(timezone=True),
        nullable=True,
    )
    op.add_column("connector_credential_pair", sync_time_column)
def downgrade() -> None:
    """Drop ``last_time_external_group_sync`` from ``connector_credential_pair``."""
    table_name = "connector_credential_pair"
    op.drop_column(table_name, "last_time_external_group_sync")

View File

@@ -1,30 +0,0 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable UUID ``creator_id`` column to ``connector_credential_pair``."""
    creator_column = sa.Column(
        "creator_id",
        sa.UUID(as_uuid=True),
        nullable=True,
    )
    op.add_column("connector_credential_pair", creator_column)
def downgrade() -> None:
    """Drop ``creator_id`` from ``connector_credential_pair``."""
    table_name = "connector_credential_pair"
    op.drop_column(table_name, "creator_id")

View File

@@ -16,41 +16,6 @@ class ExternalAccess:
is_public: bool
@dataclass(frozen=True)
class DocExternalAccess:
    """Pairs a document ID with that document's external access information.

    ``to_dict`` / ``from_dict`` convert to and from a plain dict in which the
    email and group-id sets are represented as lists.
    """

    external_access: ExternalAccess
    # The document ID
    doc_id: str

    def to_dict(self) -> dict:
        """Return a plain-dict form of this object (sets become lists)."""
        access = self.external_access
        access_payload = {
            "external_user_emails": list(access.external_user_emails),
            "external_user_group_ids": list(access.external_user_group_ids),
            "is_public": access.is_public,
        }
        return {
            "external_access": access_payload,
            "doc_id": self.doc_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DocExternalAccess":
        """Rebuild an instance from the dict layout produced by ``to_dict``.

        The email and group-id lists may be absent (treated as empty);
        ``is_public`` and ``doc_id`` are required and raise ``KeyError`` if missing.
        """
        access_payload = data["external_access"]
        emails = set(access_payload.get("external_user_emails", []))
        group_ids = set(access_payload.get("external_user_group_ids", []))
        rebuilt_access = ExternalAccess(
            external_user_emails=emails,
            external_user_group_ids=group_ids,
            is_public=access_payload["is_public"],
        )
        return cls(external_access=rebuilt_access, doc_id=data["doc_id"])
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin

View File

@@ -15,7 +15,6 @@ class UserRole(str, Enum):
for all groups they are a member of
"""
LIMITED = "limited"
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"

View File

@@ -228,17 +228,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -299,17 +294,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
if not tenant_id:
@@ -662,24 +652,10 @@ async def current_user_with_expired_token(
return await double_check_user(user, include_expired=True)
async def current_limited_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
user = await double_check_user(user)
if not user:
return None
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
return await double_check_user(user)
async def current_curator_or_admin_user(
@@ -735,6 +711,8 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -784,22 +762,15 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request,
scopes: List[str] = Query(None),
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -858,11 +829,8 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -904,6 +872,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -24,8 +24,6 @@ from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
@@ -138,22 +136,6 @@ def on_task_postrun(
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPermissionSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorExternalGroupSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""

View File

@@ -12,7 +12,6 @@ from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -73,15 +72,6 @@ class DynamicTenantScheduler(PersistentScheduler):
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")

View File

@@ -91,7 +91,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -6,7 +6,6 @@ from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@@ -60,7 +59,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
@@ -82,11 +81,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any

View File

@@ -92,6 +92,5 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.doc_permission_syncing",
]
)

View File

@@ -20,8 +20,6 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -136,10 +134,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@@ -239,8 +233,6 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",

View File

@@ -0,0 +1,96 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
    """Persistent beat scheduler that periodically rebuilds per-tenant entries.

    On ``tick`` it checks whether the reload interval (one minute) has elapsed
    and, if so, regenerates one schedule entry per (task, tenant) pair, named
    ``"<task name>-<tenant id>"``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._reload_interval = timedelta(minutes=1)
        # Back-date the last reload by one full interval.
        self._last_reload = self.app.now() - self._reload_interval

    def setup_schedule(self) -> None:
        """Delegate to the parent implementation unchanged."""
        super().setup_schedule()

    def tick(self) -> float:
        """Run the normal tick, then refresh tenant tasks if the interval elapsed."""
        next_delay = super().tick()
        current_time = self.app.now()

        never_reloaded = self._last_reload is None
        if never_reloaded or (current_time - self._last_reload) > self._reload_interval:
            logger.info("Reloading schedule to check for new tenants...")
            self._update_tenant_tasks()
            self._last_reload = current_time

        return next_delay

    def _update_tenant_tasks(self) -> None:
        """Rebuild the per-tenant schedule and apply it when the task names changed.

        All exceptions are logged and swallowed.
        """
        logger.info("Checking for tenant task updates...")
        try:
            tenant_ids = get_all_tenant_ids()
            tasks_to_schedule = fetch_versioned_implementation(
                "danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
            )

            store = getattr(self, "_store", {"entries": {}})
            current_schedule = store.get("entries", {})

            # The text after the last "-" in an entry name is taken as its tenant id.
            existing_tenants = {
                entry_name.split("-")[-1]
                for entry_name in current_schedule.keys()
                if "-" in entry_name
            }

            new_beat_schedule: dict[str, dict[str, Any]] = {}
            for tenant_id in tenant_ids:
                if tenant_id not in existing_tenants:
                    logger.info(f"Found new tenant: {tenant_id}")

                for task in tasks_to_schedule():
                    entry: dict[str, Any] = {
                        "task": task["task"],
                        "schedule": task["schedule"],
                        "kwargs": {"tenant_id": tenant_id},
                    }
                    options = task.get("options")
                    if options:
                        entry["options"] = options
                    new_beat_schedule[f"{task['name']}-{tenant_id}"] = entry

            if not self._should_update_schedule(current_schedule, new_beat_schedule):
                logger.debug("No schedule updates needed")
            else:
                logger.info(
                    "Updating schedule",
                    extra={
                        "new_tasks": len(new_beat_schedule),
                        "current_tasks": len(current_schedule),
                    },
                )
                if not hasattr(self, "_store"):
                    self._store: dict[str, dict] = {"entries": {}}
                self.update_from_dict(new_beat_schedule)
                logger.info(f"New schedule: {new_beat_schedule}")
                logger.info("Tenant tasks updated successfully")
        except (AttributeError, KeyError):
            logger.exception("Failed to process task configuration")
        except Exception:
            logger.exception("Unexpected error updating tenant tasks")

    def _should_update_schedule(
        self, current_schedule: dict, new_schedule: dict
    ) -> bool:
        """Return True when the two schedules do not have the same set of task names."""
        return set(current_schedule.keys()) != set(new_schedule.keys())

View File

@@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector(
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.

View File

@@ -8,7 +8,7 @@ tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=20),
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
@@ -20,13 +20,13 @@ tasks_to_schedule = [
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=15),
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=15),
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
@@ -41,18 +41,6 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": "check_for_doc_permissions_sync",
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": "check_for_external_group_sync",
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]

View File

@@ -143,12 +143,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
f"cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()

View File

@@ -1,321 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.document import upsert_document_external_perms
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Return True when an external doc-permissions sync is due for ``cc_pair``.

    A sync is only considered for pairs whose access type is SYNC and whose
    status is ACTIVE. For those, it is due when the pair has never been synced,
    when no sync period is configured for the connector's source, or when the
    configured period has elapsed since the last sync.
    """
    if cc_pair.access_type != AccessType.SYNC:
        return False

    # Only ACTIVE pairs are synced. Any other status (including DELETING) is
    # rejected here, which made the separate DELETING check that used to
    # follow unreachable.
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    # If the last sync is None, it has never been run so we run the sync
    last_perm_sync = cc_pair.last_time_perm_sync
    if last_perm_sync is None:
        return True

    # If DOC_PERMISSION_SYNC_PERIODS has no (or a falsy) entry for this source,
    # we always run the sync.
    source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
    if not source_sync_period:
        return True

    # Otherwise the sync is due once the source's sync period has elapsed.
    next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
    return datetime.now(timezone.utc) >= next_sync
@shared_task(
    name="check_for_doc_permissions_sync",
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
    """Queue doc-permission sync tasks for every auto-sync cc_pair that is due.

    Guarded by a non-blocking Redis beat lock so overlapping runs exit
    immediately. Exceptions are logged rather than propagated.
    """
    r = get_redis_client(tenant_id=tenant_id)

    beat_lock = r.lock(
        DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        acquired = beat_lock.acquire(blocking=False)
        if not acquired:
            return

        # get all cc pairs that need to be synced
        with get_session_with_tenant(tenant_id) as db_session:
            due_cc_pair_ids: list[int] = [
                cc_pair.id
                for cc_pair in get_all_auto_sync_cc_pairs(db_session)
                if _is_external_doc_permissions_sync_due(cc_pair)
            ]

        for cc_pair_id in due_cc_pair_ids:
            tasks_created = try_creating_permissions_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
            if tasks_created:
                task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
    finally:
        # Only release the lock if this invocation still holds it.
        if beat_lock.owned():
            beat_lock.release()
def try_creating_permissions_sync_task(
    app: Celery,
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Send the doc-permission sync generator task for one cc_pair, if allowed.

    Returns 1 once the generator task has been sent and the initial fence set.
    Returns None when the function lock cannot be acquired, when a permissions,
    deletion or pruning fence is already set for the pair, or when an
    unexpected exception occurs (it is logged).
    """
    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    lock_timeout = 30
    function_lock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
        timeout=lock_timeout,
    )

    if not function_lock.acquire(blocking_timeout=lock_timeout / 2):
        return None

    try:
        # Skip while a permission sync, deletion or prune is fenced for this pair.
        if (
            redis_connector.permissions.fenced
            or redis_connector.delete.fenced
            or redis_connector.prune.fenced
        ):
            return None

        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()

        generator_task_id = (
            f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
        )
        app.send_task(
            "connector_permission_sync_generator_task",
            kwargs={"cc_pair_id": cc_pair_id, "tenant_id": tenant_id},
            queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
            task_id=generator_task_id,
            priority=DanswerCeleryPriority.HIGH,
        )

        # set a basic fence to start
        initial_fence = RedisConnectorPermissionSyncData(started=None)
        redis_connector.permissions.set_fence(initial_fence)
    except Exception:
        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
        return None
    finally:
        if function_lock.owned():
            function_lock.release()

    return 1
@shared_task(
    name="connector_permission_sync_generator_task",
    acks_late=False,
    soft_time_limit=JOB_TIMEOUT,
    track_started=True,
    trail=False,
    bind=True,
)
def connector_permission_sync_generator_task(
    self: Task,
    cc_pair_id: int,
    tenant_id: str | None,
) -> None:
    """Run the doc-permission sync for one cc_pair and generate per-document tasks.

    Assumes the caller has already set the permissions fence. Exits early if
    another run holds the per-connector lock. On failure the generator state,
    task set and fence are cleared and the exception is re-raised.
    """
    # Record the cc_pair and celery request id in the logging context variable.
    ctx_values = doc_permission_sync_ctx.get()
    ctx_values["cc_pair_id"] = cc_pair_id
    ctx_values["request_id"] = self.request.id
    doc_permission_sync_ctx.set(ctx_values)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    r = get_redis_client(tenant_id=tenant_id)

    lock = r.lock(
        DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
    )

    if not lock.acquire(blocking=False):
        task_logger.warning(
            f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
        )
        return None

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
            if cc_pair is None:
                raise ValueError(
                    f"No connector credential pair found for id: {cc_pair_id}"
                )

            source_type = cc_pair.connector.source

            doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
            if doc_sync_func is None:
                raise ValueError(f"No doc sync func found for {source_type}")

            logger.info(f"Syncing docs for {source_type}")

            # Re-set the fence, now carrying the start timestamp.
            started_fence = RedisConnectorPermissionSyncData(
                started=datetime.now(timezone.utc),
            )
            redis_connector.permissions.set_fence(started_fence)

            document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

            task_logger.info(
                f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
            )
            tasks_generated = redis_connector.permissions.generate_tasks(
                self.app, lock, document_external_accesses, source_type
            )
            if tasks_generated is None:
                return None

            task_logger.info(
                f"RedisConnector.permissions.generate_tasks finished. "
                f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
            )

            redis_connector.permissions.generator_complete = tasks_generated

    except Exception as e:
        task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")

        # Reset all sync state for this pair before propagating the error.
        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()
        redis_connector.permissions.set_fence(None)
        raise e
    finally:
        if lock.owned():
            lock.release()
@shared_task(
    name="update_external_document_permissions_task",
    soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
    time_limit=LIGHT_TIME_LIMIT,
    max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
    bind=True,
)
def update_external_document_permissions_task(
    self: Task,
    tenant_id: str | None,
    serialized_doc_external_access: dict,
    source_string: str,
) -> bool:
    """Persist one document's external permissions.

    ``serialized_doc_external_access`` is deserialized with
    ``DocExternalAccess.from_dict``. Returns True on success; any exception
    raised while writing is logged and reported as False.
    """
    doc_access = DocExternalAccess.from_dict(serialized_doc_external_access)
    doc_id = doc_access.doc_id
    external_access = doc_access.external_access

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            # Add the referenced user emails first (only those not already
            # present, going by the helper's name), then upsert the permissions.
            user_emails = list(external_access.external_user_emails)
            batch_add_non_web_user_if_not_exists(
                db_session=db_session,
                emails=user_emails,
            )
            upsert_document_external_perms(
                db_session=db_session,
                doc_id=doc_id,
                external_access=external_access,
                source_type=DocumentSource(source_string),
            )

            logger.debug(
                f"Successfully synced postgres document permissions for {doc_id}"
            )
        return True
    except Exception:
        logger.exception("Error Syncing Document Permissions")
        return False

View File

@@ -1,265 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Return True when an external group sync is due for ``cc_pair``.

    A sync is only considered for pairs whose access type is SYNC, whose status
    is ACTIVE, and whose source has a group sync function registered. For
    those, it is due when the pair has never been group-synced, when
    EXTERNAL_GROUP_SYNC_PERIOD is unset/falsy, or when that period has elapsed
    since the last group sync.
    """
    if cc_pair.access_type != AccessType.SYNC:
        return False

    # Skip external group sync if the pair is not active. Any other status
    # (including DELETING) is rejected here, which made the separate DELETING
    # check that used to follow unreachable.
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    # If there is no group sync function for the connector, we don't run the sync.
    # This is fine because not all sources necessarily have a concept of groups.
    if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
        return False

    # If the last sync is None, it has never been run so we run the sync
    last_ext_group_sync = cc_pair.last_time_external_group_sync
    if last_ext_group_sync is None:
        return True

    # If EXTERNAL_GROUP_SYNC_PERIOD is None (or otherwise falsy), we always run the sync.
    source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
    if not source_sync_period:
        return True

    # Otherwise the sync is due once the sync period has elapsed.
    next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
    return datetime.now(timezone.utc) >= next_sync
@shared_task(
    name="check_for_external_group_sync",
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
    """Queue external group sync tasks for every auto-sync cc_pair that is due.

    Guarded by a non-blocking Redis beat lock so overlapping runs exit
    immediately. Exceptions are logged rather than propagated.
    """
    r = get_redis_client(tenant_id=tenant_id)

    beat_lock = r.lock(
        DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        acquired = beat_lock.acquire(blocking=False)
        if not acquired:
            return

        # Collect the ids while the session is open; queue tasks afterwards.
        with get_session_with_tenant(tenant_id) as db_session:
            due_cc_pair_ids: list[int] = [
                cc_pair.id
                for cc_pair in get_all_auto_sync_cc_pairs(db_session)
                if _is_external_group_sync_due(cc_pair)
            ]

        for cc_pair_id in due_cc_pair_ids:
            tasks_created = try_creating_permissions_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
            if tasks_created:
                task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
    finally:
        # Only release the lock if this invocation still holds it.
        if beat_lock.owned():
            beat_lock.release()
def try_creating_permissions_sync_task(
    app: Celery,
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Try to enqueue one external group sync generator task for ``cc_pair_id``.

    Returns 1 when a task was sent and the fence was set. Returns None when the
    function lock could not be acquired, when a sync is already fenced, or when
    an exception occurred (the exception is logged, not raised).
    """
    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    lock_timeout = 30

    function_lock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
        timeout=lock_timeout,
    )

    # wait at most half the lock's own timeout for another caller to finish
    if not function_lock.acquire(blocking_timeout=lock_timeout / 2):
        return None

    try:
        group_sync = redis_connector.external_group_sync

        # Dont kick off a new sync if the previous one is still running
        if group_sync.fenced:
            return None

        group_sync.generator_clear()
        group_sync.taskset_clear()

        task_id = f"{group_sync.taskset_key}_{uuid4()}"

        app.send_task(
            "connector_external_group_sync_generator_task",
            kwargs={"cc_pair_id": cc_pair_id, "tenant_id": tenant_id},
            queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
            task_id=task_id,
            priority=DanswerCeleryPriority.HIGH,
        )

        # set a basic fence to start
        # NOTE(review): the fence is set only after the task has been sent;
        # confirm the generator task cannot finish and clear the fence first.
        group_sync.set_fence(True)
    except Exception:
        task_logger.exception(
            f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
        )
        return None
    finally:
        if function_lock.owned():
            function_lock.release()

    return 1
@shared_task(
    name="connector_external_group_sync_generator_task",
    acks_late=False,
    soft_time_limit=JOB_TIMEOUT,
    track_started=True,
    trail=False,
    bind=True,
)
def connector_external_group_sync_generator_task(
    self: Task,
    cc_pair_id: int,
    tenant_id: str | None,
) -> None:
    """Sync external user groups for one connector credential pair.

    Looks up the group sync function for the pair's source in
    ``GROUP_PERMISSIONS_FUNC_MAP``, calls it, stores the result with
    ``replace_user__ext_group_for_cc_pair`` and marks the pair as synced.
    This task assumes it has already been properly fenced by its creator.

    Raises:
        ValueError: if the cc_pair does not exist or its source has no group
            sync function. Any exception is logged and re-raised after the
            generator and taskset keys are cleared.
    """
    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_client = get_redis_client(tenant_id=tenant_id)

    sync_lock = redis_client.lock(
        DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
    )

    try:
        if not sync_lock.acquire(blocking=False):
            task_logger.warning(
                f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
            )
            return None

        with get_session_with_tenant(tenant_id) as db_session:
            cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
            if cc_pair is None:
                raise ValueError(
                    f"No connector credential pair found for id: {cc_pair_id}"
                )

            source_type = cc_pair.connector.source

            fetch_groups = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
            if fetch_groups is None:
                raise ValueError(f"No external group sync func found for {source_type}")

            logger.info(f"Syncing docs for {source_type}")

            groups: list[ExternalUserGroup] = fetch_groups(cc_pair)

            logger.info(
                f"Syncing {len(groups)} external user groups for {source_type}"
            )

            replace_user__ext_group_for_cc_pair(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
                group_defs=groups,
                source=cc_pair.connector.source,
            )
            logger.info(
                f"Synced {len(groups)} external user groups for {source_type}"
            )

            mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
    except Exception as exc:
        task_logger.exception(
            f"Failed to run external group sync: cc_pair={cc_pair_id}"
        )

        failed_sync = redis_connector.external_group_sync
        failed_sync.generator_clear()
        failed_sync.taskset_clear()
        raise exc
    finally:
        # we always want to clear the fence after the task is done or failed so it doesn't get stuck
        # NOTE(review): this also runs on the "already running" early return
        # above, clearing the fence while another task may hold the lock -
        # confirm that is intended.
        redis_connector.external_group_sync.set_fence(False)
        if sync_lock.owned():
            sync_lock.release()

View File

@@ -38,35 +38,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Return True when the scheduled prune time for ``cc_pair`` has been reached.

    Pruning is never due when the connector has no prune frequency, when the
    pair is not ACTIVE, or when it has neither been pruned nor successfully
    indexed yet.
    """
    prune_freq = cc_pair.connector.prune_freq

    # no prune frequency means no scheduled pruning; pruning can still be
    # forced via the API which will run a pruning task directly
    if not prune_freq:
        return False

    # only active pairs are pruned
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    # if never pruned, measure from the last successful index instead;
    # if we've never indexed either, we can't prune
    reference_time = cc_pair.last_pruned or cc_pair.last_successful_index_time
    if not reference_time:
        return False

    next_prune = reference_time + timedelta(seconds=prune_freq)
    return datetime.now(timezone.utc) >= next_prune
@shared_task(
name="check_for_pruning",
soft_time_limit=JOB_TIMEOUT,
@@ -98,7 +69,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not cc_pair:
continue
if not _is_pruning_due(cc_pair):
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
@@ -119,6 +90,47 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def is_pruning_due(
    cc_pair: ConnectorCredentialPair,
    db_session: Session,
    r: Redis,
) -> bool:
    """Return True when a scheduled prune of ``cc_pair`` is due.

    Only scheduling conditions are checked here:

    - the connector has a prune frequency set,
    - the cc_pair status is ACTIVE,
    - ``prune_freq`` seconds have passed since the last prune, or since the
      last successful index if the pair was never pruned.

    Returns False if any condition fails, including when the pair has never
    been indexed. ``db_session`` and ``r`` are part of the signature but are
    not used by this function.
    """
    # skip pruning if no prune frequency is set
    # pruning can still be forced via the API which will run a pruning task directly
    if not cc_pair.connector.prune_freq:
        return False

    # skip pruning if not active
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    # skip pruning if the next scheduled prune time hasn't been reached yet
    last_pruned = cc_pair.last_pruned
    if not last_pruned:
        # if never pruned, use the last time the connector indexed successfully
        last_pruned = cc_pair.last_successful_index_time
        if not last_pruned:
            # if we've never indexed, we can't prune
            return False

    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
    return datetime.now(timezone.utc) >= next_prune
def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
@@ -154,16 +166,10 @@ def try_creating_prune_generator_task(
return None
try:
# skip pruning if already pruning
if redis_connector.prune.fenced:
if redis_connector.prune.fenced: # skip pruning if already pruning
return None
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
return None
db_session.refresh(cc_pair)

View File

@@ -59,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
task_logger.info(f"tenant={tenant_id} doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -141,9 +141,7 @@ def document_by_cc_pair_cleanup_task(
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
@@ -173,8 +171,8 @@ def document_by_cc_pair_cleanup_task(
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id):

View File

@@ -27,7 +27,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
@@ -59,10 +58,6 @@ from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
@@ -551,47 +546,6 @@ def monitor_ccpair_pruning_taskset(
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
    """Check progress of a fenced permissions sync and finalize it when done.

    ``key_bytes`` is the fence key, from which the cc_pair id is parsed. When
    the sync is fenced, its generator has completed and no tasks remain, the
    cc_pair is marked as permissions-synced and the taskset, generator and
    fence are cleared. ``r`` is not used by this function.
    """
    fence_key = key_bytes.decode("utf-8")
    parsed_id = RedisConnector.get_id_from_fence_key(fence_key)
    if parsed_id is None:
        task_logger.warning(
            f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
        )
        return

    cc_pair_id = int(parsed_id)

    permissions = RedisConnector(tenant_id, cc_pair_id).permissions
    if not permissions.fenced:
        return

    # None means the generator has not reported completion yet
    initial = permissions.generator_complete
    if initial is None:
        return

    remaining = permissions.get_remaining()
    task_logger.info(
        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
    )
    if remaining > 0:
        return

    payload = permissions.payload
    started_at = payload.started if payload else None

    mark_cc_pair_as_permissions_synced(db_session, cc_pair_id, started_at)
    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

    permissions.taskset_clear()
    permissions.generator_clear()
    permissions.set_fence(None)
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -714,17 +668,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"pruning={n_pruning}"
)
# do some cleanup before clearing fences
@@ -738,22 +688,20 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
a.connector_credential_pair_id, a.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -793,12 +741,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@@ -869,9 +811,7 @@ def vespa_metadata_sync_task(
)
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()

View File

@@ -29,26 +29,18 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
"""Ensure the parent proc's database connections are not touched
in the new connection pool
Based on SQLAlchemy's recommendations to handle multiprocessing:
Based on the recommended approach in the SQLAlchemy docs found:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
logger.info("Initializing spawned worker child process.")
# Reset the engine in the child process
SqlEngine.reset_engine()
# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
return func(*args, **kwargs)

View File

@@ -33,8 +33,8 @@ from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -427,7 +427,7 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
IndexAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:

View File

@@ -0,0 +1,4 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -14,6 +14,15 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@@ -503,7 +503,3 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")

View File

@@ -80,10 +80,6 @@ CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -213,17 +209,9 @@ class PostgresAdvisoryLocks(Enum):
class DanswerCeleryQueues:
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
@@ -233,18 +221,8 @@ class DanswerRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"

View File

@@ -119,14 +119,3 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
# if specified, will merge the specified JSON with the existing body of the
# request before sending it to the LLM
LITELLM_EXTRA_BODY: dict | None = None
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
if _LITELLM_EXTRA_BODY_RAW:
try:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass

View File

@@ -146,7 +146,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
self.wiki_base, confluence_object["_links"]["webui"]
)
object_text = None
@@ -278,9 +278,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
self.wiki_base, page["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)
@@ -295,9 +293,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
self.wiki_base, attachment["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)

View File

@@ -100,39 +100,6 @@ def extract_text_from_confluence_html(
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ri:page"):
# Wrap this in a try-except because there are some pages that might not exist
try:
page_title = html_page_reference.attrs["ri:content-title"]
if not page_title:
continue
page_query = f"type=page and title='{page_title}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page_batch in confluence_client.paginated_cql_page_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page_batch[0]
break
except Exception:
logger.warning(
f"Error getting page contents for object {confluence_object}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client, page_contents
)
html_page_reference.replaceWith(text_from_page)
return format_document_soup(soup)
@@ -186,9 +153,7 @@ def attachment_to_content(
return extracted_text
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
def build_confluence_document_id(base_url: str, content_url: str) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
@@ -199,8 +164,6 @@ def build_confluence_document_id(
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"

View File

@@ -305,7 +305,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
logger.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,

View File

@@ -192,33 +192,23 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self) -> list[str]:
# Start with primary admin email
user_emails = [self.primary_admin_email]
# Only fetch additional users if using service account
if isinstance(self.creds, OAuthCredentials):
return user_emails
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
# Get admins first since they're more likely to have access to most files
for is_admin in [True, False]:
query = "isAdmin=true" if is_admin else "isAdmin=false"
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
if email not in user_emails:
user_emails.append(email)
return user_emails
query = "isAdmin=true" if admins_only else "isAdmin=false"
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
@@ -226,48 +216,55 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
user_email=self.primary_admin_email,
)
all_drive_ids = set()
# We don't want to fail if we're using OAuth because you can
# access your my drive as a non admin user in an org still
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
continue_on_404_or_403=ignore_fetch_failure,
useDomainAdminAccess=True,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
if not all_drive_ids:
logger.warning(
"No drives found. This is likely because oauth user "
"is not an admin and cannot view all drive IDs. "
"Continuing with only the shared drive IDs specified in the config."
)
all_drive_ids = set(self._requested_shared_drive_ids)
return all_drive_ids
def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first because they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails
self._all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - no specific emails were requested
# - the current user's email is in the requested emails
# - we are using OAuth (in which case we assume that is the only email we will try)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
or isinstance(self.creds, OAuthCredentials)
):
yield from get_all_files_in_my_drive(
service=drive_service,
@@ -277,7 +274,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
@@ -288,7 +285,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_folders = filtered_folder_ids - self._retrieved_ids
remaining_folders = self._requested_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
@@ -305,56 +302,22 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
filtered_folder_ids = self._requested_folder_ids - all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
if invalid_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_drive_ids)
# If including shared drives, use the requested IDs if provided,
# otherwise use all drive IDs
filtered_drive_ids = set()
if self.include_shared_drives:
if self._requested_shared_drive_ids:
# Remove invalid drive IDs from requested IDs
filtered_drive_ids = (
self._requested_shared_drive_ids - invalid_drive_ids
)
else:
filtered_drive_ids = all_drive_ids
self._initialize_all_class_variables()
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval,
email,
is_slim,
filtered_drive_ids,
filtered_folder_ids,
start,
end,
self._impersonate_user_for_retrieval, email, is_slim, start, end
): email
for email in all_org_emails
for email in self._all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
remaining_folders = (
filtered_drive_ids | filtered_folder_ids
) - self._retrieved_ids
remaining_folders = self._requested_folder_ids - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"

View File

@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
logger.warning(f"Error executing request: {e}")
results = {}
else:
raise e

View File

@@ -55,11 +55,11 @@ def validate_channel_names(
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
15 # How often pods send heartbeats to indicate they are still processing a tenant
60 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

View File

@@ -17,8 +17,6 @@ from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.app_configs import POD_NAME
from danswer.configs.app_configs import POD_NAMESPACE
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
@@ -77,7 +75,6 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -87,9 +84,7 @@ logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants",
"Number of active tenants handled by this pod",
["namespace", "pod"],
"active_tenants", "Number of active tenants handled by this pod"
)
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
@@ -152,9 +147,7 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
active_tenants_gauge.set(len(self.tenant_ids))
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
@@ -171,15 +164,9 @@ class SlackbotHandler:
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
@@ -203,9 +190,6 @@ class SlackbotHandler:
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
@@ -252,14 +236,14 @@ class SlackbotHandler:
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if self.socket_clients.get(tenant_id):
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if self.socket_clients.get(tenant_id):
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
@@ -293,14 +277,14 @@ class SlackbotHandler:
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
if client:
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
@@ -314,16 +298,6 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)

View File

@@ -282,32 +282,3 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()
def mark_cc_pair_as_permissions_synced(
    db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
    """Store ``start_time`` as the pair's ``last_time_perm_sync`` and commit.

    Args:
        db_session: session used for both the lookup and the commit.
        cc_pair_id: primary key of the connector/credential pair to update.
        start_time: value written to ``last_time_perm_sync`` (may be None).

    Raises:
        ValueError: if no connector/credential pair has the given ID.
    """
    matching_pair = db_session.scalar(
        select(ConnectorCredentialPair).where(
            ConnectorCredentialPair.id == cc_pair_id
        )
    )
    if matching_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    matching_pair.last_time_perm_sync = start_time
    db_session.commit()
def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
    """Stamp the pair's ``last_time_external_group_sync`` with the current UTC time.

    Args:
        db_session: session used for both the lookup and the commit.
        cc_pair_id: primary key of the connector/credential pair to update.

    Raises:
        ValueError: if no connector/credential pair has the given ID.
    """
    lookup = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    target_pair = db_session.scalar(lookup)
    if target_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    # The sync time can be marked after it ran because all group syncs
    # are run in full, not polling for changes.
    # If this changes, we need to update this function.
    finished_at = datetime.now(timezone.utc)
    target_pair.last_time_external_group_sync = finished_at
    db_session.commit()

View File

@@ -76,10 +76,8 @@ def _add_user_filters(
.where(~UG__CCpair.user_group_id.in_(user_groups))
.correlate(ConnectorCredentialPair)
)
where_clause |= ConnectorCredentialPair.creator_id == user.id
else:
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC
return stmt.where(where_clause)
@@ -389,7 +387,6 @@ def add_credential_to_connector(
)
association = ConnectorCredentialPair(
creator_id=user.id if user else None,
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,

View File

@@ -19,7 +19,6 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -47,21 +46,13 @@ def count_documents_by_needs_sync(session: Session) -> int:
"""Get the count of all documents where:
1. last_modified is newer than last_synced
2. last_synced is null (meaning we've never synced)
AND the document has a relationship with a connector/credential pair
TODO: The documents without a relationship with a connector/credential pair
should be cleaned up somehow eventually.
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count(DbDocument.id.distinct()))
session.query(func.count())
.select_from(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.filter(
or_(
DbDocument.last_modified > DbDocument.last_synced,
@@ -100,22 +91,6 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
    db_session: Session, cc_pair_id: int
) -> list[DbDocument]:
    """Return the documents selected by the needs-sync query for one cc pair.

    The pair is resolved from ``cc_pair_id``; its connector and credential IDs
    are then fed to
    ``construct_document_select_for_connector_credential_pair_by_needs_sync``.

    Raises:
        ValueError: if no connector/credential pair has the given ID.
    """
    pair = get_connector_credential_pair_from_id(
        cc_pair_id=cc_pair_id, db_session=db_session
    )
    if not pair:
        raise ValueError(f"No CC pair found with ID: {cc_pair_id}")

    needs_sync_select = (
        construct_document_select_for_connector_credential_pair_by_needs_sync(
            pair.connector_id, pair.credential_id
        )
    )
    rows = db_session.scalars(needs_sync_select).all()
    return list(rows)
def construct_document_select_for_connector_credential_pair(
connector_id: int, credential_id: int | None = None
) -> Select:
@@ -129,21 +104,6 @@ def construct_document_select_for_connector_credential_pair(
return stmt
def get_documents_for_cc_pair(
    db_session: Session,
    cc_pair_id: int,
) -> list[DbDocument]:
    """Return the documents selected for one connector/credential pair.

    The pair is resolved from ``cc_pair_id``; its connector and credential IDs
    are then fed to ``construct_document_select_for_connector_credential_pair``.

    Raises:
        ValueError: if no connector/credential pair has the given ID.
    """
    pair = get_connector_credential_pair_from_id(
        cc_pair_id=cc_pair_id, db_session=db_session
    )
    if not pair:
        raise ValueError(f"No CC pair found with ID: {cc_pair_id}")

    document_select = construct_document_select_for_connector_credential_pair(
        connector_id=pair.connector_id, credential_id=pair.credential_id
    )
    rows = db_session.scalars(document_select).all()
    return list(rows)
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> list[str]:
@@ -308,7 +268,7 @@ def get_access_info_for_documents(
return db_session.execute(stmt).all() # type: ignore
def _upsert_documents(
def upsert_documents(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
initial_boost: int = DEFAULT_BOOST,
@@ -346,8 +306,6 @@ def _upsert_documents(
]
)
# This does not update the permissions of the document if
# the document already exists.
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
@@ -364,7 +322,7 @@ def _upsert_documents(
db_session.commit()
def _upsert_document_by_connector_credential_pair(
def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
@@ -446,8 +404,8 @@ def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
_upsert_documents(db_session, document_metadata_batch)
_upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
@@ -505,6 +463,7 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
logger.info(f"Deleting {len(document_ids)} documents from the DB")
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session

View File

@@ -189,13 +189,6 @@ class SqlEngine:
return ""
return cls._app_name
@classmethod
def reset_engine(cls) -> None:
    """Dispose the class-level engine, if one is set, and clear the reference.

    The check and the reset both happen while holding ``cls._lock``.
    """
    with cls._lock:
        current_engine = cls._engine
        if current_engine:
            current_engine.dispose()
            cls._engine = None
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
@@ -319,9 +312,7 @@ async def get_async_session_with_tenant(
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
except Exception:
logger.exception("Error setting search_path.")
@@ -382,9 +373,7 @@ def get_session_with_tenant(
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
finally:
cursor.close()

View File

@@ -126,8 +126,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
)
visible_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
@@ -173,11 +173,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
back_populates="creator",
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
class InputPrompt(Base):
@@ -425,9 +420,6 @@ class ConnectorCredentialPair(Base):
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
@@ -460,14 +452,6 @@ class ConnectorCredentialPair(Base):
"IndexAttempt", back_populates="connector_credential_pair"
)
# the user id of the user that created this cc pair
creator_id: Mapped[UUID | None] = mapped_column(nullable=True)
creator: Mapped["User"] = relationship(
"User",
back_populates="cc_pairs",
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)
class Document(Base):
__tablename__ = "document"

View File

@@ -743,4 +743,5 @@ def delete_persona_by_name(
)
db_session.execute(stmt)
db_session.commit()

View File

@@ -97,18 +97,3 @@ def batch_add_non_web_user_if_not_exists__no_commit(
db_session.flush() # generate ids
return found_users + new_users
def batch_add_non_web_user_if_not_exists(
    db_session: Session, emails: list[str]
) -> list[User]:
    """Create non-web users for the emails that have no user yet, then commit.

    Args:
        db_session: session used for the lookup, the inserts and the commit.
        emails: addresses to look up.

    Returns:
        The users found by ``get_users_by_emails`` followed by the newly
        created ones.
    """
    existing_users, emails_without_user = get_users_by_emails(db_session, emails)

    created_users: list[User] = [
        _generate_non_web_user(email=address) for address in emails_without_user
    ]

    db_session.add_all(created_users)
    db_session.commit()

    return existing_users + created_users

View File

@@ -56,7 +56,7 @@ class IndexingPipelineProtocol(Protocol):
...
def _upsert_documents_in_db(
def upsert_documents_in_db(
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
@@ -243,7 +243,7 @@ def index_doc_batch_prepare(
# Create records in the source of truth about these documents,
# does not include doc_updated_at which is also used to indicate a successful update
_upsert_documents_in_db(
upsert_documents_in_db(
documents=documents,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
@@ -255,7 +255,7 @@ def index_doc_batch_prepare(
)
@log_function_time(debug_only=True)
@log_function_time()
def index_doc_batch(
*,
chunker: Chunker,

View File

@@ -26,7 +26,6 @@ from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.llm.interfaces import ToolChoiceOptions
@@ -214,7 +213,6 @@ class DefaultMultiLLM(LLM):
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
):
self._timeout = timeout
self._model_provider = model_provider
@@ -248,8 +246,6 @@ class DefaultMultiLLM(LLM):
model_kwargs: dict[str, Any] = {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
model_kwargs.update({"extra_body": extra_body})
self._model_kwargs = model_kwargs

View File

@@ -315,7 +315,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,

View File

@@ -1,8 +1,6 @@
import redis
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -21,10 +19,6 @@ class RedisConnector:
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
self.external_group_sync = RedisConnectorExternalGroupSync(
tenant_id, id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(

View File

@@ -63,7 +63,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (

View File

@@ -1,187 +0,0 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from danswer.access.models import DocExternalAccess
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
class RedisConnectorPermissionSyncData(BaseModel):
    """JSON payload stored under a doc permission sync fence key."""

    # NOTE(review): presumably when the sync was started, None if not yet
    # known - confirm against the code that writes the fence.
    started: datetime | None
class RedisConnectorPermissionSync:
    """Manages interactions with redis for doc permission sync tasks. Should only be accessed
    through RedisConnector.

    Every key is derived from ``PREFIX`` plus the connector/credential pair id
    passed to ``__init__``.
    """

    PREFIX = "connectordocpermissionsync"

    FENCE_PREFIX = f"{PREFIX}_fence"  # connectordocpermissionsync_fence

    # phase 1 - generator task and progress signals
    GENERATORTASK_PREFIX = f"{PREFIX}+generator"  # connectordocpermissionsync+generator
    GENERATOR_PROGRESS_PREFIX = (
        PREFIX + "_generator_progress"
    )  # connectordocpermissionsync_generator_progress
    GENERATOR_COMPLETE_PREFIX = (
        PREFIX + "_generator_complete"
    )  # connectordocpermissionsync_generator_complete

    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectordocpermissionsync_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectordocpermissionsync+sub

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        """Bind the helper to one tenant / id and precompute its redis keys."""
        self.tenant_id: str | None = tenant_id
        self.id = id
        self.redis = redis

        self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
        self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
        self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
        self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"

        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        """Delete the set that tracks outstanding subtask ids."""
        self.redis.delete(self.taskset_key)

    def generator_clear(self) -> None:
        """Delete the generator progress and generator complete keys."""
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)

    def get_remaining(self) -> int:
        """Return the number of task ids still present in the taskset."""
        return cast(int, self.redis.scard(self.taskset_key))

    def get_active_task_count(self) -> int:
        """Count of active permission sync tasks

        Counts every key matching ``FENCE_PREFIX*``, i.e. fences for all ids,
        not only the one this instance is bound to.
        """
        return sum(
            1
            for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*")
        )

    @property
    def fenced(self) -> bool:
        """True when this instance's fence key exists in redis."""
        return bool(self.redis.exists(self.fence_key))

    @property
    def payload(self) -> RedisConnectorPermissionSyncData | None:
        """Return the parsed fence payload, or None when no fence is set."""
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes | None, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        return RedisConnectorPermissionSyncData.model_validate_json(
            fence_bytes.decode("utf-8")
        )

    def set_fence(
        self,
        payload: RedisConnectorPermissionSyncData | None,
    ) -> None:
        """Write ``payload`` as JSON to the fence key; a falsy payload deletes it."""
        if not payload:
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())

    @property
    def generator_complete(self) -> int | None:
        """the fence payload is an int representing the starting number of
        permission sync tasks to be processed ... just after the generator completes."""
        fence_bytes = self.redis.get(self.generator_complete_key)
        # a missing key and a stored literal b"None" both mean "not complete"
        if fence_bytes is None or fence_bytes == b"None":
            return None

        return int(cast(bytes, fence_bytes).decode())

    @generator_complete.setter
    def generator_complete(self, payload: int | None) -> None:
        """Set the payload to an int to set the fence, otherwise if None it will
        be deleted"""
        if payload is None:
            self.redis.delete(self.generator_complete_key)
            return

        self.redis.set(self.generator_complete_key, payload)

    def generate_tasks(
        self,
        celery_app: Celery,
        lock: redis.lock.Lock | None,
        new_permissions: list[DocExternalAccess],
        source_string: str,
    ) -> int | None:
        """Send one permission upsert task per entry in ``new_permissions``.

        Args:
            celery_app: app used to dispatch the tasks.
            lock: optional redis lock that is reacquired periodically while
                the loop runs.
            new_permissions: per-document access records; each is serialized
                with ``to_dict()`` into the task kwargs.
            source_string: forwarded unchanged to every task.

        Returns:
            The number of tasks sent.
        """
        last_lock_time = time.monotonic()
        tasks_sent = 0

        # Create a task for each document permission sync
        for doc_perm in new_permissions:
            current_time = time.monotonic()
            # reacquire at a quarter of the lock timeout so a long loop
            # does not outlive the lock
            if lock and current_time - last_lock_time >= (
                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
            ):
                lock.reacquire()
                last_lock_time = current_time

            # Add task for document permissions sync
            custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
            # the id is recorded in the taskset before the task is sent
            self.redis.sadd(self.taskset_key, custom_task_id)

            celery_app.send_task(
                "update_external_document_permissions_task",
                kwargs=dict(
                    tenant_id=self.tenant_id,
                    serialized_doc_external_access=doc_perm.to_dict(),
                    source_string=source_string,
                ),
                queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
                task_id=custom_task_id,
                priority=DanswerCeleryPriority.MEDIUM,
            )
            tasks_sent += 1

        return tasks_sent

    @staticmethod
    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
        """Remove ``task_id`` from the taskset belonging to ``id``."""
        taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
        r.srem(taskset_key, task_id)

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        key_prefixes = (
            RedisConnectorPermissionSync.TASKSET_PREFIX,
            RedisConnectorPermissionSync.GENERATOR_COMPLETE_PREFIX,
            RedisConnectorPermissionSync.GENERATOR_PROGRESS_PREFIX,
            RedisConnectorPermissionSync.FENCE_PREFIX,
        )
        for key_prefix in key_prefixes:
            for key in r.scan_iter(key_prefix + "*"):
                r.delete(key)

View File

@@ -1,133 +0,0 @@
from typing import cast
import redis
from celery import Celery
from sqlalchemy.orm import Session
class RedisConnectorExternalGroupSync:
    """Manages interactions with redis for external group syncing tasks. Should only be accessed
    through RedisConnector."""

    PREFIX = "connectorexternalgroupsync"

    FENCE_PREFIX = f"{PREFIX}_fence"  # connectorexternalgroupsync_fence

    # phase 1 - generator task and progress signals
    GENERATORTASK_PREFIX = f"{PREFIX}+generator"  # connectorexternalgroupsync+generator
    # connectorexternalgroupsync_generator_progress
    GENERATOR_PROGRESS_PREFIX = f"{PREFIX}_generator_progress"
    # connectorexternalgroupsync_generator_complete
    GENERATOR_COMPLETE_PREFIX = f"{PREFIX}_generator_complete"

    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorexternalgroupsync_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorexternalgroupsync+sub

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        """Bind the helper to one tenant / id and precompute its redis keys."""
        self.tenant_id: str | None = tenant_id
        self.id = id
        self.redis = redis

        suffix = f"_{id}"
        self.fence_key: str = self.FENCE_PREFIX + suffix
        self.generator_task_key = self.GENERATORTASK_PREFIX + suffix
        self.generator_progress_key = self.GENERATOR_PROGRESS_PREFIX + suffix
        self.generator_complete_key = self.GENERATOR_COMPLETE_PREFIX + suffix

        self.taskset_key = self.TASKSET_PREFIX + suffix

        self.subtask_prefix: str = self.SUBTASK_PREFIX + suffix

    def taskset_clear(self) -> None:
        """Delete the set that tracks outstanding subtask ids."""
        self.redis.delete(self.taskset_key)

    def generator_clear(self) -> None:
        """Delete the generator progress and generator complete keys."""
        for generator_key in (self.generator_progress_key, self.generator_complete_key):
            self.redis.delete(generator_key)

    def get_remaining(self) -> int:
        """Return the number of task ids still present in the taskset."""
        # todo: move into fence
        return cast(int, self.redis.scard(self.taskset_key))

    def get_active_task_count(self) -> int:
        """Count of active external group syncing tasks"""
        fence_pattern = RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"
        return sum(1 for _ in self.redis.scan_iter(fence_pattern))

    @property
    def fenced(self) -> bool:
        """True when this instance's fence key exists in redis."""
        return True if self.redis.exists(self.fence_key) else False

    def set_fence(self, value: bool) -> None:
        """Create the fence key (stored value 0) when ``value`` is truthy, else delete it."""
        if value:
            self.redis.set(self.fence_key, 0)
        else:
            self.redis.delete(self.fence_key)

    @property
    def generator_complete(self) -> int | None:
        """the fence payload is an int representing the starting number of
        external group syncing tasks to be processed ... just after the generator completes.
        """
        stored = self.redis.get(self.generator_complete_key)
        # a missing key and a stored literal b"None" both mean "not complete"
        if stored is None or stored == b"None":
            return None

        return int(cast(bytes, stored).decode())

    @generator_complete.setter
    def generator_complete(self, payload: int | None) -> None:
        """Set the payload to an int to set the fence, otherwise if None it will
        be deleted"""
        if payload is not None:
            self.redis.set(self.generator_complete_key, payload)
        else:
            self.redis.delete(self.generator_complete_key)

    def generate_tasks(
        self,
        celery_app: Celery,
        db_session: Session,
        lock: redis.lock.Lock | None,
    ) -> int | None:
        """Placeholder: performs no work and returns None."""
        return None

    @staticmethod
    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
        """Remove ``task_id`` from the taskset belonging to ``id``."""
        r.srem(f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}", task_id)

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        owner = RedisConnectorExternalGroupSync
        for key_prefix in (
            owner.TASKSET_PREFIX,
            owner.GENERATOR_COMPLETE_PREFIX,
            owner.GENERATOR_PROGRESS_PREFIX,
            owner.FENCE_PREFIX,
        ):
            for key in r.scan_iter(key_prefix + "*"):
                r.delete(key)

File diff suppressed because it is too large Load Diff

View File

@@ -1,44 +0,0 @@
[
{
"url": "https://docs.danswer.dev/more/use_cases/overview",
"title": "Use Cases Overview",
"content": "How to leverage Danswer in your organization\n\nDanswer Overview\nDanswer is the AI Assistant connected to your organization's docs, apps, and people. Danswer makes Generative AI more versatile for work by enabling new types of questions like \"What is the most common feature request we've heard from customers this month\". Whereas other AI systems have no context of your team and are generally unhelpful with work related questions, Danswer makes it possible to ask these questions in natural language and get back answers in seconds.\n\nDanswer can connect to +30 different tools and the use cases are not limited to the ones in the following pages. The highlighted use cases are for inspiration and come from feedback gathered from our users and customers.\n\n\nCommon Getting Started Questions:\n\nWhy are these docs connected in my Danswer deployment?\nAnswer: This is just an example of how connectors work in Danswer. You can connect up your own team's knowledge and you will be able to ask questions unique to your organization. Danswer will keep all of the knowledge up to date and in sync with your connected applications.\n\nIs my data being sent anywhere when I connect it up to Danswer?\nAnswer: No! Danswer is built with data security as our highest priority. We open sourced it so our users can know exactly what is going on with their data. By default all of the document processing happens within Danswer. The only time it is sent outward is for the GenAI call to generate answers.\n\nWhere is the feature for auto sync-ing document level access permissions from all connected sources?\nAnswer: This falls under the Enterprise Edition set of Danswer features built on top of the MIT/community edition. If you are on Danswer Cloud, you have access to them by default. If you're running it yourself, reach out to the Danswer team to receive access.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
"title": "Enterprise Search",
"content": "Value of Enterprise Search with Danswer\n\nWhat is Enterprise Search and why is it Important?\nAn Enterprise Search system gives team members a single place to access all of the disparate knowledge of an organization. Critical information is saved across a host of channels like call transcripts with prospects, engineering design docs, IT runbooks, customer support email exchanges, project management tickets, and more. As fast moving teams scale up, information gets spread out and more disorganized.\n\nSince it quickly becomes infeasible to check across every source, decisions get made on incomplete information, employee satisfaction decreases, and the most valuable members of your team are tied up with constant distractions as junior teammates are unable to unblock themselves. Danswer solves this problem by letting anyone on the team access all of the knowledge across your organization in a permissioned and secure way. Users can ask questions in natural language and get back answers and documents across all of the connected sources instantly.\n\nWhat's the real cost?\nA typical knowledge worker spends over 2 hours a week on search, but more than that, the cost of incomplete or incorrect information can be extremely high. Customer support/success that isn't able to find the reference to similar cases could cause hours or even days of delay leading to lower customer satisfaction or in the worst case - churn. An account exec not realizing that a prospect had previously mentioned a specific need could lead to lost deals. An engineer not realizing a similar feature had previously been built could result in weeks of wasted development time and tech debt with duplicate implementation. With a lack of knowledge, your whole organization is navigating in the dark - inefficient and mistake prone.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
"title": "Enterprise Search",
"content": "More than Search\nWhen analyzing the entire corpus of knowledge within your company is as easy as asking a question in a search bar, your entire team can stay informed and up to date. Danswer also makes it trivial to identify where knowledge is well documented and where it is lacking. Team members who are centers of knowledge can begin to effectively document their expertise since it is no longer being thrown into a black hole. All of this allows the organization to achieve higher efficiency and drive business outcomes.\n\nWith Generative AI, the entire user experience has evolved as well. For example, instead of just finding similar cases for your customer support team to reference, Danswer breaks down the issue and explains it so that even the most junior members can understand it. This in turn lets them give the most holistic and technically accurate response possible to your customers. On the other end, even the super stars of your sales team will not be able to review 10 hours of transcripts before hopping on that critical call, but Danswer can easily parse through it in mere seconds and give crucial context to help your team close.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/ai_platform",
"title": "AI Platform",
"content": "Build AI Agents powered by the knowledge and workflows specific to your organization.\n\nBeyond Answers\nAgents enabled by generative AI and reasoning capable models are helping teams to automate their work. Danswer is helping teams make it happen. Danswer provides out of the box user chat sessions, attaching custom tools, handling LLM reasoning, code execution, data analysis, referencing internal knowledge, and much more.\n\nDanswer as a platform is not a no-code agent builder. We are made by developers for developers and this gives your team the full flexibility and power to create agents not constrained by blocks and simple logic paths.\n\nFlexibility and Extensibility\nDanswer is open source and completely whitebox. This not only gives transparency to what happens within the system but also means that your team can directly modify the source code to suit your unique needs.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nDanswer takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/sales",
"title": "Sales",
"content": "Keep your team up to date on every conversation and update so they can close.\n\nRecall Every Detail\nBeing able to instantly revisit every detail of any call without reading transcripts is helping Sales teams provide more tailored pitches, build stronger relationships, and close more deals. Instead of searching and reading through hours of transcripts in preparation for a call, your team can now ask Danswer \"What specific features was ACME interested in seeing for the demo\". Since your team doesn't have time to read every transcript prior to a call, Danswer provides a more thorough summary because it can instantly parse hundreds of pages and distill out the relevant information. Even for fast lookups it becomes much more convenient - for example to brush up on connection building topics by asking \"What rapport building topic did we chat about in the last call with ACME\".\n\nKnow Every Product Update\nIt is impossible for Sales teams to keep up with every product update. Because of this, when a prospect has a question that the Sales team does not know, they have no choice but to rely on the Product and Engineering orgs to get an authoritative answer. Not only is this distracting to the other teams, it also slows down the time to respond to the prospect (and as we know, time is the biggest killer of deals). With Danswer, it is even possible to get answers live on call because of how fast accessing information becomes. A question like \"Have we shipped the Microsoft AD integration yet?\" can now be answered in seconds meaning that prospects can get answers while on the call instead of asynchronously and sales cycles are reduced as a result.",
"chunk_ind": 0
},
{
"url": "https://docs.danswer.dev/more/use_cases/operations",
"title": "Operations",
"content": "Double the productivity of your Ops teams like IT, HR, etc.\n\nAutomatically Resolve Tickets\nModern teams are leveraging AI to auto-resolve up to 50% of tickets. Whether it is an employee asking about benefits details or how to set up the VPN for remote work, Danswer can help your team help themselves. This frees up your team to do the real impactful work of landing star candidates or improving your internal processes.\n\nAI Aided Onboarding\nOne of the periods where your team needs the most help is when they're just ramping up. Instead of feeling lost in dozens of new tools, Danswer gives them a single place where they can ask about anything in natural language. Whether it's how to set up their work environment or what their onboarding goals are, Danswer can walk them through every step with the help of Generative AI. This lets your team feel more empowered and gives time back to the more seasoned members of your team to focus on moving the needle.",
"chunk_ind": 0
}
]

View File

@@ -32,7 +32,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import ConnectorBase
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@@ -91,21 +91,7 @@ def _create_indexable_chunks(
return list(ids_to_documents.values()), chunks
# Cohere is used in EE version
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
    """Load the bundled, pre-processed seed documents from disk.

    Reads ``danswer/seeding/initial_docs.json`` relative to the current
    working directory and returns the parsed JSON content.

    Args:
        cohere_enabled: Not used by this implementation; it is accepted so
            callers can pass the flag regardless of which implementation
            they end up invoking.

    Returns:
        The parsed contents of ``initial_docs.json``.

    Raises:
        FileNotFoundError: If the seed file is not present under the current
            working directory.
        json.JSONDecodeError: If the seed file is not valid JSON.
    """
    initial_docs_path = os.path.join(
        os.getcwd(),
        "danswer",
        "seeding",
        "initial_docs.json",
    )
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with open(initial_docs_path, encoding="utf-8") as initial_docs_file:
        processed_docs = json.load(initial_docs_file)
    return processed_docs
def seed_initial_documents(
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
def seed_initial_documents(db_session: Session, tenant_id: str | None) -> None:
"""
Seed initial documents so users don't have an empty index to start
@@ -146,9 +132,7 @@ def seed_initial_documents(
return
search_settings = get_current_search_settings(db_session)
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL and not (
search_settings.model_name == "embed-english-v3.0" and cohere_enabled
):
if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL:
logger.info("Embedding model has been updated, skipping")
return
@@ -188,10 +172,11 @@ def seed_initial_documents(
last_successful_index_time=last_index_time,
)
cc_pair_id = cast(int, result.data)
processed_docs = fetch_versioned_implementation(
"danswer.seeding.load_docs",
"load_processed_docs",
)(cohere_enabled)
initial_docs_path = os.path.join(
os.getcwd(), "danswer", "seeding", "initial_docs.json"
)
processed_docs = json.load(open(initial_docs_path))
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)

View File

@@ -6,7 +6,6 @@ from starlette.routing import BaseRoute
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
@@ -103,8 +102,7 @@ def check_router_auth(
for dependency in route_dependant_obj.dependencies:
depends_fn = dependency.cache_key[0]
if (
depends_fn == current_limited_user
or depends_fn == current_user
depends_fn == current_user
or depends_fn == current_admin_user
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
@@ -120,5 +118,5 @@ def check_router_auth(
# print(f"(\"{route.path}\", {set(route.methods)}),")
raise RuntimeError(
f"Did not find user dependency in private route - {route}"
f"Did not find current_user or current_admin_user dependency in route - {route}"
)

View File

@@ -12,13 +12,13 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.doc_permission_syncing.tasks import (
try_creating_permissions_sync_task,
)
from danswer.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -26,7 +26,6 @@ from danswer.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.document import get_documents_for_cc_pair
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
@@ -39,13 +38,15 @@ from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import DocumentSyncStatus
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
@@ -287,12 +288,12 @@ def prune_cc_pair(
)
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
def get_cc_pair_latest_sync(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
) -> CeleryTaskStatus:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -302,20 +303,34 @@ def get_cc_pair_latest_sync(
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
detail="Connection not found for current user's permissions",
)
return cc_pair.last_time_perm_sync
# look up the last sync task for this connector (if it exists)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
if not last_sync_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No sync task found.",
)
return CeleryTaskStatus(
id=last_sync_task.task_id,
name=last_sync_task.task_name,
status=last_sync_task.status,
start_time=last_sync_task.start_time,
register_time=last_sync_task.register_time,
)
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
def sync_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers permissions sync on a particular cc_pair immediately"""
# avoiding circular refs
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -329,49 +344,37 @@ def sync_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
last_sync_task = get_latest_task(sync_task_name, db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if redis_connector.permissions.fenced:
if last_sync_task and check_task_is_live_and_not_timed_out(
last_sync_task, db_session
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Doc permissions sync task already in progress.",
detail="Sync task already in progress.",
)
logger.info(
f"Doc permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task = fetch_ee_implementation_or_noop(
"danswer.background.celery.apps.primary",
"sync_external_doc_permissions_task",
None,
)
tasks_created = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Doc permissions sync task creation failed.",
if sync_external_doc_permissions_task:
sync_external_doc_permissions_task.apply_async(
kwargs=dict(
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
),
)
return StatusResponse(
success=True,
message="Successfully created the doc permissions sync task.",
message="Successfully created the sync task.",
)
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
def get_docs_sync_status(
    cc_pair_id: int,
    _: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> list[DocumentSyncStatus]:
    """Return a sync-status entry for every document of the given cc_pair.

    The documents are fetched with ``get_documents_for_cc_pair`` and each one
    is converted through ``DocumentSyncStatus.from_model``. The unnamed user
    dependency is only there to enforce curator/admin access.
    """
    cc_pair_documents = get_documents_for_cc_pair(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    sync_statuses = []
    for document in cc_pair_documents:
        sync_statuses.append(DocumentSyncStatus.from_model(document))
    return sync_statuses
@router.put("/connector/{connector_id}/credential/{credential_id}")
def associate_credential_to_connector(
connector_id: int,
@@ -387,7 +390,6 @@ def associate_credential_to_connector(
user=user,
target_group_ids=metadata.groups,
object_is_public=metadata.access_type == AccessType.PUBLIC,
object_is_perm_sync=metadata.access_type == AccessType.SYNC,
)
try:

View File

@@ -81,6 +81,7 @@ from danswer.db.index_attempt import get_latest_index_attempts_by_status
from danswer.db.models import IndexingStatus
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
@@ -664,8 +665,7 @@ def create_connector_from_model(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
object_is_public=connector_data.access_type == AccessType.PUBLIC,
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
object_is_public=connector_data.is_public,
)
connector_base = connector_data.to_connector_base()
return create_connector(
@@ -683,31 +683,32 @@ def create_connector_with_mock_credential(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
object_is_public=connector_data.access_type == AccessType.PUBLIC,
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
)
if user and user.role != UserRole.ADMIN:
if connector_data.is_public:
raise HTTPException(
status_code=401,
detail="User does not have permission to create public credentials",
)
if not connector_data.groups:
raise HTTPException(
status_code=401,
detail="Curators must specify 1+ groups",
)
try:
_validate_connector_allowed(connector_data.source)
connector_response = create_connector(
db_session=db_session,
connector_data=connector_data,
db_session=db_session, connector_data=connector_data
)
mock_credential = CredentialBase(
credential_json={},
admin_public=True,
source=connector_data.source,
credential_json={}, admin_public=True, source=connector_data.source
)
credential = create_credential(
credential_data=mock_credential,
user=user,
db_session=db_session,
mock_credential, user=user, db_session=db_session
)
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
response = add_credential_to_connector(
@@ -715,7 +716,7 @@ def create_connector_with_mock_credential(
user=user,
connector_id=cast(int, connector_response.id), # will aways be an int
credential_id=credential.id,
access_type=connector_data.access_type,
access_type=access_type,
cc_pair_name=connector_data.name,
groups=connector_data.groups,
)
@@ -740,8 +741,7 @@ def update_connector_from_model(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
object_is_public=connector_data.access_type == AccessType.PUBLIC,
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
object_is_public=connector_data.is_public,
)
connector_base = connector_data.to_connector_base()
except ValueError as e:

View File

@@ -14,7 +14,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexAttemptError as DbIndexAttemptError
from danswer.db.models import IndexingStatus
@@ -22,20 +21,6 @@ from danswer.db.models import TaskStatus
from danswer.server.utils import mask_credential_dict
class DocumentSyncStatus(BaseModel):
    """Per-document sync snapshot exposed by the API.

    Carries a document id together with its ``last_synced`` and
    ``last_modified`` timestamps, either of which may be ``None``.
    """

    doc_id: str
    last_synced: datetime | None
    last_modified: datetime | None

    @classmethod
    def from_model(cls, doc: DbDocument) -> "DocumentSyncStatus":
        """Build a status object by copying the id and timestamps off ``doc``."""
        status_fields = {
            "doc_id": doc.id,
            "last_synced": doc.last_synced,
            "last_modified": doc.last_modified,
        }
        return DocumentSyncStatus(**status_fields)
class DocumentInfo(BaseModel):
num_chunks: int
num_tokens: int
@@ -64,11 +49,11 @@ class ConnectorBase(BaseModel):
class ConnectorUpdateRequest(ConnectorBase):
access_type: AccessType
is_public: bool = True
groups: list[int] = Field(default_factory=list)
def to_connector_base(self) -> ConnectorBase:
return ConnectorBase(**self.model_dump(exclude={"access_type", "groups"}))
return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"}))
class ConnectorSnapshot(ConnectorBase):
@@ -237,8 +222,6 @@ class CCPairFullInfo(BaseModel):
is_editable_for_current_user: bool
deletion_failure_message: str | None
indexing: bool
creator: UUID | None
creator_email: str | None
@classmethod
def from_models(
@@ -284,10 +267,6 @@ class CCPairFullInfo(BaseModel):
is_editable_for_current_user=is_editable_for_current_user,
deletion_failure_message=cc_pair_model.deletion_failure_message,
indexing=indexing,
creator=cc_pair_model.creator_id,
creator_email=cc_pair_model.creator.email
if cc_pair_model.creator
else None,
)

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
@@ -273,7 +272,7 @@ def list_personas(
@basic_router.get("/{persona_id}")
def get_persona(
persona_id: int,
user: User | None = Depends(current_limited_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return PersonaSnapshot.from_model(

View File

@@ -630,25 +630,31 @@ def update_user_assistant_list(
db_session.commit()
def update_assistant_visibility(
def update_assistant_list(
preferences: UserPreferences, assistant_id: int, show: bool
) -> UserPreferences:
visible_assistants = preferences.visible_assistants or []
hidden_assistants = preferences.hidden_assistants or []
chosen_assistants = preferences.chosen_assistants or []
if show:
if assistant_id not in visible_assistants:
visible_assistants.append(assistant_id)
if assistant_id in hidden_assistants:
hidden_assistants.remove(assistant_id)
if assistant_id not in chosen_assistants:
chosen_assistants.append(assistant_id)
else:
if assistant_id in visible_assistants:
visible_assistants.remove(assistant_id)
if assistant_id not in hidden_assistants:
hidden_assistants.append(assistant_id)
if assistant_id in chosen_assistants:
chosen_assistants.remove(assistant_id)
preferences.visible_assistants = visible_assistants
preferences.hidden_assistants = hidden_assistants
preferences.chosen_assistants = chosen_assistants
return preferences
@@ -664,23 +670,15 @@ def update_user_assistant_visibility(
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
preferences = no_auth_user.preferences
updated_preferences = update_assistant_visibility(
preferences, assistant_id, show
)
if updated_preferences.chosen_assistants is not None:
updated_preferences.chosen_assistants.append(assistant_id)
updated_preferences = update_assistant_list(preferences, assistant_id, show)
set_no_auth_user_preferences(store, updated_preferences)
return
else:
raise RuntimeError("This should never happen")
user_preferences = UserInfo.from_model(user).preferences
updated_preferences = update_assistant_visibility(
user_preferences, assistant_id, show
)
if updated_preferences.chosen_assistants is not None:
updated_preferences.chosen_assistants.append(assistant_id)
updated_preferences = update_assistant_list(user_preferences, assistant_id, show)
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore

View File

@@ -18,7 +18,6 @@ from PIL import Image
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
@@ -310,7 +309,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_limited_user),
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
@@ -392,7 +391,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_limited_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None

View File

@@ -9,7 +9,6 @@ from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
@@ -263,7 +262,7 @@ def stream_query_validation(
@basic_router.post("/stream-answer-with-quote")
def get_answer_with_quote(
query_request: DirectQARequest,
user: User = Depends(current_limited_user),
user: User = Depends(current_user),
_: None = Depends(check_token_rate_limits),
) -> StreamingResponse:
query = query_request.messages[0].message

View File

@@ -59,9 +59,7 @@ from shared_configs.model_server_models import SupportedEmbeddingModel
logger = setup_logger()
def setup_danswer(
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
def setup_danswer(db_session: Session, tenant_id: str | None) -> None:
"""
Setup Danswer for a particular tenant. In the Single Tenant case, it will set it up for the default schema
on server startup. In the MT case, it will be called when the tenant is created.
@@ -150,7 +148,7 @@ def setup_danswer(
# update multipass indexing setting based on GPU availability
update_default_multipass_indexing(db_session)
seed_initial_documents(db_session, tenant_id, cohere_enabled)
seed_initial_documents(db_session, tenant_id)
def translate_saved_search_settings(db_session: Session) -> None:

View File

@@ -21,12 +21,8 @@ pruning_ctx: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
"pruning_ctx", default=dict()
)
doc_permission_sync_ctx: contextvars.ContextVar[
dict[str, Any]
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
class TaskAttemptSingleton:
class IndexAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
main background job (scheduler), etc. this will not be used."""
@@ -70,10 +66,9 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
) -> tuple[str, MutableMapping[str, Any]]:
# If this is an indexing job, add the attempt ID to the log message
# This helps filter the logs for this specific indexing
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
index_attempt_id = IndexAttemptSingleton.get_index_attempt_id()
cc_pair_id = IndexAttemptSingleton.get_connector_credential_pair_id()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
@@ -81,9 +76,6 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
elif len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
else:
if index_attempt_id is not None:
msg = f"[Index Attempt: {index_attempt_id}] {msg}"

View File

@@ -1,12 +1,32 @@
from danswer.background.celery.apps.primary import celery_app
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
from ee.danswer.background.celery_utils import (
should_perform_external_doc_permissions_check,
)
from ee.danswer.background.celery_utils import (
should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.external_permissions.permission_sync import (
run_external_doc_permission_sync,
)
from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -14,6 +34,25 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
# mark as EE for all tasks in this file
global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(
    cc_pair_id: int, *, tenant_id: str | None
) -> None:
    """Run the external document permission sync for a single cc_pair.

    Opens a session via ``get_session_with_tenant(tenant_id)`` and hands it,
    together with ``cc_pair_id``, to ``run_external_doc_permission_sync``.
    """
    with get_session_with_tenant(tenant_id) as tenant_session:
        run_external_doc_permission_sync(
            db_session=tenant_session,
            cc_pair_id=cc_pair_id,
        )
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(
    cc_pair_id: int, *, tenant_id: str | None
) -> None:
    """Run the external group permission sync for a single cc_pair.

    Opens a session via ``get_session_with_tenant(tenant_id)`` and hands it,
    together with ``cc_pair_id``, to ``run_external_group_permission_sync``.
    """
    with get_session_with_tenant(tenant_id) as tenant_session:
        run_external_group_permission_sync(
            db_session=tenant_session,
            cc_pair_id=cc_pair_id,
        )
@build_celery_task_wrapper(name_chat_ttl_task)
@@ -28,6 +67,38 @@ def perform_ttl_management_task(
#####
# Periodic Tasks
#####
@celery_app.task(
    name="check_sync_external_doc_permissions_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
    """Periodic task: queue a doc permission sync for every cc_pair that is due.

    Iterates over the cc_pairs returned by ``get_all_auto_sync_cc_pairs`` and,
    for each one that ``should_perform_external_doc_permissions_check``
    approves, enqueues ``sync_external_doc_permissions_task`` for it.
    """
    with get_session_with_tenant(tenant_id) as db_session:
        for auto_sync_pair in get_all_auto_sync_cc_pairs(db_session):
            sync_is_due = should_perform_external_doc_permissions_check(
                cc_pair=auto_sync_pair, db_session=db_session
            )
            if not sync_is_due:
                continue
            task_kwargs = dict(cc_pair_id=auto_sync_pair.id, tenant_id=tenant_id)
            sync_external_doc_permissions_task.apply_async(kwargs=task_kwargs)
@celery_app.task(
    name="check_sync_external_group_permissions_task",
    soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
    """Periodic task: queue a group permission sync for every cc_pair that is due.

    Iterates over the cc_pairs returned by ``get_all_auto_sync_cc_pairs`` and,
    for each one that ``should_perform_external_group_permissions_check``
    approves, enqueues ``sync_external_group_permissions_task`` for it.
    """
    with get_session_with_tenant(tenant_id) as db_session:
        for auto_sync_pair in get_all_auto_sync_cc_pairs(db_session):
            sync_is_due = should_perform_external_group_permissions_check(
                cc_pair=auto_sync_pair, db_session=db_session
            )
            if not sync_is_due:
                continue
            task_kwargs = dict(cc_pair_id=auto_sync_pair.id, tenant_id=tenant_id)
            sync_external_group_permissions_task.apply_async(kwargs=task_kwargs)
@celery_app.task(

View File

@@ -6,6 +6,16 @@ from danswer.background.celery.tasks.beat_schedule import (
)
ee_tasks_to_schedule = [
{
"name": "sync-external-doc-permissions",
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=30), # TODO: optimize this
},
{
"name": "sync-external-group-permissions",
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=60), # TODO: optimize this
},
{
"name": "autogenerate_usage_report",
"task": "autogenerate_usage_report_task",

View File

@@ -1,13 +1,46 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.db.enums import AccessType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
logger = setup_logger()
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
    """Decide whether a permission sync for ``cc_pair`` is due.

    Returns True when any of the following holds:
      * ``PERMISSION_SYNC_PERIODS`` has no (or a falsy) entry for the
        connector's source,
      * the cc_pair has never been synced (``last_time_perm_sync`` is None),
      * more than the configured number of seconds has elapsed since
        ``last_time_perm_sync``.
    Otherwise returns False.
    """
    source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)

    # If PERMISSION_SYNC_PERIODS has no period for this source, always sync.
    if not source_sync_period:
        return True

    # If the last sync is None, it has never been run so we run the sync
    last_sync = cc_pair.last_time_perm_sync
    if last_sync is None:
        return True

    # Only stamp UTC onto naive timestamps. Overwriting the tzinfo of an
    # already-aware, non-UTC value would silently shift the instant it
    # represents and skew the elapsed-time calculation.
    if last_sync.utcoffset() is None:
        last_sync = last_sync.replace(tzinfo=timezone.utc)

    current_time = datetime.now(timezone.utc)

    # Run the sync once more than the configured period has elapsed.
    return (current_time - last_sync).total_seconds() > source_sync_period
def should_perform_chat_ttl_check(
retention_limit_days: int | None, db_session: Session
) -> bool:
@@ -24,3 +57,47 @@ def should_perform_chat_ttl_check(
logger.debug(f"{task_name} is already being performed. Skipping.")
return False
return True
def should_perform_external_doc_permissions_check(
    cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
    """Tell whether a doc permission sync should be queued for ``cc_pair``.

    False for cc_pairs whose access type is not SYNC. True if no earlier sync
    task is recorded. False (with a debug log) while the latest task is still
    live and not timed out. Otherwise the answer is whatever
    ``_is_time_to_run_sync`` reports.
    """
    if cc_pair.access_type != AccessType.SYNC:
        return False

    task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
    previous_task = get_latest_task(task_name, db_session)

    # Nothing has ever run for this cc_pair, so a sync is due.
    if not previous_task:
        return True

    still_running = check_task_is_live_and_not_timed_out(previous_task, db_session)
    if still_running:
        logger.debug(f"{task_name} is already being performed. Skipping.")
        return False

    return True if _is_time_to_run_sync(cc_pair) else False
def should_perform_external_group_permissions_check(
    cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
    """Tell whether a group permission sync should be queued for ``cc_pair``.

    False for cc_pairs whose access type is not SYNC. True if no earlier sync
    task is recorded. False (with a debug log) while the latest task is still
    live and not timed out. Otherwise the answer is whatever
    ``_is_time_to_run_sync`` reports.
    """
    if cc_pair.access_type != AccessType.SYNC:
        return False

    task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
    previous_task = get_latest_task(task_name, db_session)

    # Nothing has ever run for this cc_pair, so a sync is due.
    if not previous_task:
        return True

    still_running = check_task_is_live_and_not_timed_out(previous_task, db_session)
    if still_running:
        logger.debug(f"{task_name} is already being performed. Skipping.")
        return False

    return True if _is_time_to_run_sync(cc_pair) else False

View File

@@ -1,2 +1,8 @@
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_group_permissions_task__{cc_pair_id}"

View File

@@ -1,6 +1,3 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -48,53 +45,3 @@ def upsert_document_external_perms__no_commit(
document.external_user_emails = list(external_access.external_user_emails)
document.external_user_group_ids = prefixed_external_groups
document.is_public = external_access.is_public
def upsert_document_external_perms(
    db_session: Session,
    doc_id: str,
    external_access: ExternalAccess,
    source_type: DocumentSource,
) -> None:
    """
    This sets the permissions for a document in postgres and commits.
    NOTE: this will replace any existing external access, it will not do a union

    If no document row with ``doc_id`` exists yet, a placeholder row carrying
    only the external access is inserted. If a row exists, it is updated (and
    its ``last_modified`` bumped) only when the external access differs from
    what is stored.

    Args:
        db_session: Session used for the lookup, the write and the commit.
        doc_id: Id of the document whose permissions are being set.
        external_access: The emails, group ids and public flag to store.
        source_type: Source used to prefix each external group id.
    """
    document = db_session.scalars(
        select(DbDocument).where(DbDocument.id == doc_id)
    ).first()

    prefixed_external_groups: set[str] = {
        prefix_group_w_source(
            ext_group_name=group_id,
            source=source_type,
        )
        for group_id in external_access.external_user_group_ids
    }

    if not document:
        # If the document does not exist, still store the external access
        # So that if the document is added later, the external access is already stored
        # The upsert function in the indexing pipeline does not overwrite the permissions fields
        document = DbDocument(
            id=doc_id,
            semantic_id="",
            # Store lists here as well, matching the update path below, so
            # both branches hand the same collection type to the model.
            external_user_emails=list(external_access.external_user_emails),
            external_user_group_ids=list(prefixed_external_groups),
            is_public=external_access.is_public,
        )
        db_session.add(document)
        db_session.commit()
        return

    # If the document exists, we need to check if the external access has changed
    if (
        external_access.external_user_emails != set(document.external_user_emails or [])
        or prefixed_external_groups != set(document.external_user_group_ids or [])
        or external_access.is_public != document.is_public
    ):
        document.external_user_emails = list(external_access.external_user_emails)
        document.external_user_group_ids = list(prefixed_external_groups)
        document.is_public = external_access.is_public
        document.last_modified = datetime.now(timezone.utc)
        db_session.commit()

View File

@@ -9,12 +9,11 @@ from sqlalchemy.orm import Session
from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import User__ExternalUserGroupId
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
class ExternalUserGroup(BaseModel):
id: str
user_emails: list[str]
user_ids: list[UUID]
def delete_user__ext_group_for_user__no_commit(
@@ -39,7 +38,7 @@ def delete_user__ext_group_for_cc_pair__no_commit(
)
def replace_user__ext_group_for_cc_pair(
def replace_user__ext_group_for_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
group_defs: list[ExternalUserGroup],
@@ -47,44 +46,24 @@ def replace_user__ext_group_for_cc_pair(
) -> None:
"""
This function clears all existing external user group relations for a given cc_pair_id
and replaces them with the new group definitions and commits the changes.
and replaces them with the new group definitions.
"""
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# collect all emails from all groups to batch add all users at once for efficiency
all_group_member_emails = set()
for external_group in group_defs:
for user_email in external_group.user_emails:
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
all_group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(all_group_member_emails)
)
# map emails to ids
email_id_map = {user.email: user.id for user in all_group_members}
# use these ids to create new external user group relations relating group_id to user_ids
new_external_permissions = []
for external_group in group_defs:
for user_email in external_group.user_emails:
user_id = email_id_map[user_email]
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=prefix_group_w_source(
external_group.id, source
),
cc_pair_id=cc_pair_id,
)
)
new_external_permissions = [
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=prefix_group_w_source(external_group.id, source),
cc_pair_id=cc_pair_id,
)
for external_group in group_defs
for user_id in external_group.user_ids
]
db_session.add_all(new_external_permissions)
db_session.commit()
def fetch_external_groups_for_user(

View File

@@ -124,21 +124,16 @@ def _cleanup_document_set__user_group_relationships__no_commit(
def validate_user_creation_permissions(
db_session: Session,
user: User | None,
target_group_ids: list[int] | None = None,
object_is_public: bool | None = None,
object_is_perm_sync: bool | None = None,
target_group_ids: list[int] | None,
object_is_public: bool | None,
) -> None:
"""
All users can create/edit permission synced objects if they don't specify a group
All admin actions are allowed.
Prevents non-admins from creating/editing:
- public objects
- objects with no groups
- objects that belong to a group they don't curate
"""
if object_is_perm_sync and not target_group_ids:
return
if not user or user.role == UserRole.ADMIN:
return

View File

@@ -4,14 +4,17 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
"""
from typing import Any
from danswer.access.models import DocExternalAccess
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.confluence.connector import ConfluenceConnector
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
logger = setup_logger()
@@ -160,13 +163,7 @@ def _extract_read_access_restrictions(
f"Email for user {user['username']} not found in Confluence"
)
else:
if user.get("email") is not None:
logger.warning(f"Cant find email for user {user.get('displayName')}")
logger.warning(
"This user needs to make their email accessible in Confluence Settings"
)
logger.warning(f"no user email or username for {user}")
logger.warning(f"User {user} does not have an email or username")
# Extract the groups with read access
read_access_group = read_access_restrictions.get("group", {})
@@ -193,12 +190,12 @@ def _fetch_all_page_restrictions_for_space(
confluence_client: OnyxConfluence,
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
) -> list[DocExternalAccess]:
) -> dict[str, ExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
Otherwise, use the space's restrictions.
"""
document_restrictions: list[DocExternalAccess] = []
document_restrictions: dict[str, ExternalAccess] = {}
for slim_doc in slim_docs:
if slim_doc.perm_sync_data is None:
@@ -210,34 +207,21 @@ def _fetch_all_page_restrictions_for_space(
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
)
if restrictions:
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
)
# If there are restrictions, then we don't need to use the space's restrictions
continue
space_key = slim_doc.perm_sync_data.get("space_key")
if space_permissions := space_permissions_by_space_key.get(space_key):
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
)
continue
logger.warning(f"No permissions found for document {slim_doc.id}")
document_restrictions[slim_doc.id] = restrictions
else:
space_key = slim_doc.perm_sync_data.get("space_key")
if space_permissions := space_permissions_by_space_key.get(space_key):
document_restrictions[slim_doc.id] = space_permissions
else:
logger.warning(f"No permissions found for document {slim_doc.id}")
return document_restrictions
def confluence_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
) -> None:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -263,8 +247,20 @@ def confluence_doc_sync(
for doc_batch in confluence_connector.retrieve_all_slim_documents():
slim_docs.extend(doc_batch)
return _fetch_all_page_restrictions_for_space(
permissions_by_doc_id = _fetch_all_page_restrictions_for_space(
confluence_client=confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
)
all_emails = set()
for doc_id, page_specific_access in permissions_by_doc_id.items():
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
external_access=page_specific_access,
source_type=cc_pair.connector.source,
)
all_emails.update(page_specific_access.external_user_emails)
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(all_emails))

View File

@@ -1,11 +1,15 @@
from typing import Any
from sqlalchemy.orm import Session
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
logger = setup_logger()
@@ -36,8 +40,9 @@ def _get_group_members_email_paginated(
def confluence_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
) -> None:
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
confluence_client = build_confluence_client(
credentials_json=cc_pair.credential.credential_json,
@@ -58,13 +63,20 @@ def confluence_group_sync(
group_member_emails = _get_group_members_email_paginated(
confluence_client, group_name
)
if not group_member_emails:
continue
danswer_groups.append(
ExternalUserGroup(
id=group_name,
user_emails=list(group_member_emails),
)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=list(group_member_emails)
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_name,
user_ids=[user.id for user in group_members],
)
)
return danswer_groups
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)

View File

@@ -1,12 +1,15 @@
from datetime import datetime
from datetime import timezone
from danswer.access.models import DocExternalAccess
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.gmail.connector import GmailConnector
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
logger = setup_logger()
@@ -28,8 +31,9 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
) -> None:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -41,7 +45,6 @@ def gmail_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if slim_doc.perm_sync_data is None:
@@ -53,11 +56,13 @@ def gmail_doc_sync(
external_user_group_ids=set(),
is_public=False,
)
document_external_access.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session,
emails=list(ext_access.external_user_emails),
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=slim_doc.id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)
return document_external_access

View File

@@ -2,7 +2,8 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from danswer.access.models import DocExternalAccess
from sqlalchemy.orm import Session
from danswer.access.models import ExternalAccess
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -10,7 +11,9 @@ from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
logger = setup_logger()
@@ -110,13 +113,8 @@ def _get_permissions_from_slim_doc(
elif permission_type == "group":
group_emails.add(permission["emailAddress"])
elif permission_type == "domain" and company_domain:
if permission.get("domain") == company_domain:
if permission["domain"] == company_domain:
public = True
else:
logger.warning(
"Permission is type domain but does not match company domain:"
f"\n {permission}"
)
elif permission_type == "anyone":
public = True
@@ -128,8 +126,9 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
) -> None:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -143,17 +142,19 @@ def gdrive_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session,
emails=list(ext_access.external_user_emails),
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=slim_doc.id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)
return document_external_accesses

View File

@@ -1,16 +1,21 @@
from sqlalchemy.orm import Session
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
logger = setup_logger()
def gdrive_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
) -> None:
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
@@ -39,14 +44,20 @@ def gdrive_group_sync(
):
group_member_emails.append(member["email"])
if not group_member_emails:
continue
danswer_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
# Add group members to DB and get their IDs
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_email, user_ids=[user.id for user in group_members]
)
)
return danswer_groups
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)

View File

@@ -0,0 +1,115 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
logger = setup_logger()
def run_external_group_permission_sync(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    """Sync external user-group membership for one connector credential pair.

    Looks up the cc pair, picks the group sync function registered for its
    source in ``GROUP_PERMISSIONS_FUNC_MAP`` and runs it, then commits the
    session. Sources without a registered group sync function are skipped.

    Any exception raised by the sync function or the commit is logged and the
    session is rolled back; it is not re-raised.

    Args:
        db_session: Session passed to the sync function and used for the
            commit / rollback.
        cc_pair_id: Id of the connector credential pair to sync.

    Raises:
        ValueError: If no connector credential pair exists for ``cc_pair_id``.
    """
    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
    if cc_pair is None:
        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

    source_type = cc_pair.connector.source
    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
    if group_sync_func is None:
        # Not all sync connectors support group permissions so this is fine
        return

    try:
        # This function updates:
        # - the user_email <-> external_user_group_id mapping
        # in postgres without committing
        logger.debug(f"Syncing groups for {source_type}")
        # group_sync_func cannot be None here (guard clause above), so it is
        # called unconditionally.
        group_sync_func(
            db_session,
            cc_pair,
        )

        # update postgres
        db_session.commit()
    except Exception:
        logger.exception("Error Syncing Group Permissions")
        db_session.rollback()
def run_external_doc_permission_sync(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    """Sync document permissions for one connector credential pair.

    Runs the doc sync function registered for the cc pair's source, recomputes
    access for every document of the cc pair, sends one update request per
    document to the current primary document index, stamps
    ``last_time_perm_sync`` with the current UTC time and commits.

    Any exception raised during those steps is logged, the previous
    ``last_time_perm_sync`` value is put back and the session is rolled back;
    the exception is not re-raised.

    Args:
        db_session: Session used for all lookups and for the commit / rollback.
        cc_pair_id: Id of the connector credential pair to sync.

    Raises:
        ValueError: If no connector credential pair exists for ``cc_pair_id``,
            or if its source has no entry in ``DOC_PERMISSIONS_FUNC_MAP``.
    """
    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
    if cc_pair is None:
        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

    source_type = cc_pair.connector.source
    doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
    # Remember the old timestamp so it can be restored if the sync fails.
    previous_sync_time = cc_pair.last_time_perm_sync

    if doc_sync_func is None:
        raise ValueError(
            f"No permission sync function found for source type: {source_type}"
        )

    try:
        # The sync function writes, without committing:
        # - the user_email <-> document mapping
        # - the external_user_group_id <-> document mapping
        logger.info(f"Syncing docs for {source_type}")
        doc_sync_func(db_session, cc_pair)

        # All document ids that belong to this cc pair
        cc_pair_doc_ids = get_document_ids_for_connector_credential_pair(
            db_session=db_session,
            connector_id=cc_pair.connector_id,
            credential_id=cc_pair.credential_id,
        )

        # Freshly computed access per document id; this is what gets pushed
        # to the document index (vespa)
        access_by_doc_id = get_access_for_documents(
            document_ids=cc_pair_doc_ids,
            db_session=db_session,
        )

        # One update request per document
        update_reqs = []
        for doc_id, doc_access in access_by_doc_id.items():
            update_reqs.append(
                UpdateRequest(document_ids=[doc_id], access=doc_access)
            )

        # Only the primary index is updated; the secondary is synced after
        # the switch anyway
        document_index = get_current_primary_default_document_index(db_session)
        document_index.update(update_reqs)

        cc_pair.last_time_perm_sync = datetime.now(timezone.utc)

        # Persist the postgres changes
        db_session.commit()
        logger.info(f"Successfully synced docs for {source_type}")
    except Exception:
        logger.exception("Error Syncing Document Permissions")
        cc_pair.last_time_perm_sync = previous_sync_time
        db_session.rollback()

View File

@@ -1,12 +1,16 @@
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.access.models import DocExternalAccess
from danswer.access.models import ExternalAccess
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import InputType
from danswer.connectors.slack.connector import get_channels
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
@@ -14,15 +18,22 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
# Get all document ids that need their permissions updated
runnable_connector = instantiate_connector(
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.SLIM_RETRIEVAL,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
assert isinstance(runnable_connector, SlimConnector)
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata_batch in runnable_connector.retrieve_all_slim_documents():
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
@@ -35,11 +46,13 @@ def _get_slack_document_ids_and_channels(
def _fetch_workspace_permissions(
db_session: Session,
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
user_emails = set()
for email in user_id_to_email_map.values():
user_emails.add(email)
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
return ExternalAccess(
external_user_emails=user_emails,
# No group<->document mapping for slack
@@ -50,6 +63,7 @@ def _fetch_workspace_permissions(
def _fetch_channel_permissions(
db_session: Session,
slack_client: WebClient,
workspace_permissions: ExternalAccess,
user_id_to_email_map: dict[str, str],
@@ -99,6 +113,9 @@ def _fetch_channel_permissions(
# If no email is found, we skip the user
continue
user_id_to_email_map[member_id] = member_email
batch_add_non_web_user_if_not_exists__no_commit(
db_session, [member_email]
)
member_emails.add(member_email)
@@ -114,8 +131,9 @@ def _fetch_channel_permissions(
def slack_doc_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[DocExternalAccess]:
) -> None:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -127,18 +145,19 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
db_session=db_session,
cc_pair=cc_pair,
)
workspace_permissions = _fetch_workspace_permissions(
db_session=db_session,
user_id_to_email_map=user_id_to_email_map,
)
channel_permissions = _fetch_channel_permissions(
db_session=db_session,
slack_client=slack_client,
workspace_permissions=workspace_permissions,
user_id_to_email_map=user_id_to_email_map,
)
document_external_accesses = []
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
@@ -146,10 +165,9 @@ def slack_doc_sync(
continue
for doc_id in doc_ids:
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=doc_id,
)
upsert_document_external_perms__no_commit(
db_session=db_session,
doc_id=doc_id,
external_access=ext_access,
source_type=cc_pair.connector.source,
)
return document_external_accesses

View File

@@ -5,11 +5,14 @@ SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EM
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
"""
from slack_sdk import WebClient
from sqlalchemy.orm import Session
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
logger = setup_logger()
@@ -26,6 +29,7 @@ def _get_slack_group_ids(
def _get_slack_group_members_email(
db_session: Session,
slack_client: WebClient,
group_name: str,
user_id_to_email_map: dict[str, str],
@@ -45,14 +49,18 @@ def _get_slack_group_members_email(
# If no email is found, we skip the user
continue
user_id_to_email_map[member_id] = member_email
batch_add_non_web_user_if_not_exists__no_commit(
db_session, [member_email]
)
group_member_emails.append(member_email)
return group_member_emails
def slack_group_sync(
db_session: Session,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
) -> None:
slack_client = WebClient(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
@@ -61,13 +69,24 @@ def slack_group_sync(
danswer_groups: list[ExternalUserGroup] = []
for group_name in _get_slack_group_ids(slack_client):
group_member_emails = _get_slack_group_members_email(
db_session=db_session,
slack_client=slack_client,
group_name=group_name,
user_id_to_email_map=user_id_to_email_map,
)
if not group_member_emails:
continue
danswer_groups.append(
ExternalUserGroup(id=group_name, user_emails=group_member_emails)
group_members = batch_add_non_web_user_if_not_exists__no_commit(
db_session=db_session, emails=group_member_emails
)
return danswer_groups
if group_members:
danswer_groups.append(
ExternalUserGroup(
id=group_name, user_ids=[user.id for user in group_members]
)
)
replace_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=danswer_groups,
source=cc_pair.connector.source,
)

View File

@@ -1,9 +1,9 @@
from collections.abc import Callable
from danswer.access.models import DocExternalAccess
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.models import ConnectorCredentialPair
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.gmail.doc_sync import gmail_doc_sync
@@ -12,18 +12,12 @@ from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
SyncFuncType = Callable[
[
Session,
ConnectorCredentialPair,
],
list[DocExternalAccess],
]
GroupSyncFuncType = Callable[
[
ConnectorCredentialPair,
],
list[ExternalUserGroup],
None,
]
# These functions update:
@@ -31,7 +25,7 @@ GroupSyncFuncType = Callable[
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
DocumentSource.CONFLUENCE: confluence_doc_sync,
DocumentSource.SLACK: slack_doc_sync,
@@ -42,21 +36,19 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
DocumentSource.SLACK: 5 * 60,
}
EXTERNAL_GROUP_SYNC_PERIOD: int = 30 # 30 seconds
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
    """Return whether ``source_type`` has an entry in DOC_PERMISSIONS_FUNC_MAP."""
    is_registered = source_type in DOC_PERMISSIONS_FUNC_MAP
    return is_registered

View File

@@ -1,5 +1,4 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID
from danswer.auth.users import auth_backend
@@ -60,31 +59,6 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,
associate_by_email=True,
is_verified_by_default=True,
# Points the user back to the login page
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,
@@ -99,7 +73,6 @@ def get_application() -> FastAPI:
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,

View File

@@ -1,45 +0,0 @@
import json
import os
from typing import cast
from typing import List
from cohere import Client
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
Embedding = List[float]
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
    """Load the seed documents from ``<cwd>/danswer/seeding``.

    When ``cohere_enabled`` is true and ``COHERE_DEFAULT_API_KEY`` is set,
    ``initial_docs_cohere.json`` is loaded and each document gets a
    ``title_embedding`` and ``content_embedding`` computed with the Cohere
    ``embed-english-v3.0`` model (two embed calls per document). Otherwise
    ``initial_docs.json`` is loaded and returned as-is.

    Args:
        cohere_enabled: Whether Cohere embeddings were requested.

    Returns:
        The list of document dicts parsed from the selected JSON file.
    """
    base_path = os.path.join(os.getcwd(), "danswer", "seeding")

    use_cohere = bool(cohere_enabled and COHERE_DEFAULT_API_KEY)
    docs_file_name = "initial_docs_cohere.json" if use_cohere else "initial_docs.json"
    initial_docs_path = os.path.join(base_path, docs_file_name)

    # Use a context manager so the file handle is closed deterministically
    # (the bare json.load(open(...)) form leaked it). JSON is read as UTF-8.
    with open(initial_docs_path, encoding="utf-8") as docs_file:
        processed_docs = json.load(docs_file)

    if not use_cohere:
        return processed_docs

    cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
    embed_model = "embed-english-v3.0"

    for doc in processed_docs:
        title_embed_response = cohere_client.embed(
            texts=[doc["title"]],
            model=embed_model,
            input_type="search_document",
        )
        content_embed_response = cohere_client.embed(
            texts=[doc["content"]],
            model=embed_model,
            input_type="search_document",
        )

        # Each call embeds a single text, so take the first (only) embedding.
        doc["title_embedding"] = cast(
            List[Embedding], title_embed_response.embeddings
        )[0]
        doc["content_embedding"] = cast(
            List[Embedding], content_embed_response.embeddings
        )[0]

    return processed_docs

View File

@@ -38,4 +38,3 @@ class ImpersonateRequest(BaseModel):
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.users import exceptions
@@ -14,8 +13,6 @@ from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
@@ -44,9 +41,7 @@ from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def get_or_create_tenant_id(
email: str, referral_source: str | None = None
) -> str:
async def get_or_create_tenant_id(email: str) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -56,7 +51,7 @@ async def get_or_create_tenant_id(
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
tenant_id = await create_tenant(email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
@@ -69,13 +64,13 @@ async def get_or_create_tenant_id(
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
async def create_tenant(email: str) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email, referral_source)
await notify_control_plane(tenant_id, email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
@@ -107,19 +102,9 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_danswer(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
except Exception as e:
@@ -132,18 +117,14 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
async def notify_control_plane(tenant_id: str, email: str) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(
tenant_id=tenant_id, email=email, referral_source=referral_source
)
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.post(
@@ -219,51 +200,11 @@ def configure_default_api_keys(db_session: Session) -> None:
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
logger.info("Attempting to upsert Cohere cloud embedding provider")
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
logger.info("Successfully upserted Cohere cloud embedding provider")
logger.info("Updating search settings with Cohere embedding model details")
query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.FUTURE)
.order_by(SearchSettings.id.desc())
)
result = db_session.execute(query)
current_search_settings = result.scalars().first()
if current_search_settings:
current_search_settings.model_name = (
"embed-english-v3.0" # Cohere's latest model as of now
)
current_search_settings.model_dim = (
1024 # Cohere's embed-english-v3.0 dimension
)
current_search_settings.provider_type = EmbeddingProvider.COHERE
current_search_settings.index_name = (
"danswer_chunk_cohere_embed_english_v3_0"
)
current_search_settings.query_prefix = ""
current_search_settings.passage_prefix = ""
db_session.commit()
else:
raise RuntimeError(
"No search settings specified, DB is not in a valid state"
)
logger.info("Fetching updated search settings to verify changes")
updated_query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.PRESENT)
.order_by(SearchSettings.id.desc())
)
updated_result = db_session.execute(updated_query)
updated_result.scalars().first()
except Exception:
logger.exception("Failed to configure Cohere embedding provider")
except Exception as e:
logger.error(f"Failed to configure Cohere embedding provider: {e}")
else:
logger.info(
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)

View File

@@ -26,5 +26,4 @@ lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
cohere==5.6.1
pandas-stubs==2.2.3.241009

View File

@@ -1,2 +1 @@
python3-saml==1.15.0
cohere==5.6.1
python3-saml==1.15.0

View File

@@ -42,7 +42,7 @@ def run_jobs() -> None:
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion",
]
cmd_worker_heavy = [
@@ -56,7 +56,7 @@ def run_jobs() -> None:
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
"connector_pruning",
]
cmd_worker_indexing = [

View File

@@ -142,20 +142,6 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"
DISALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("DISALLOWED_SLACK_BOT_TENANT_IDS")
DISALLOWED_SLACK_BOT_TENANT_LIST = (
[tenant.strip() for tenant in DISALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
if DISALLOWED_SLACK_BOT_TENANT_IDS
else None
)
IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_IDS")
IGNORED_SYNCING_TENANT_LIST = (
[tenant.strip() for tenant in IGNORED_SYNCING_TENANT_IDS.split(",")]
if IGNORED_SYNCING_TENANT_IDS
else None
)
SUPPORTED_EMBEDDING_MODELS = [
# Cloud-based models
SupportedEmbeddingModel(

View File

@@ -33,7 +33,7 @@ stopasgroup=true
command=celery -A danswer.background.celery.versioned_apps.light worker
--loglevel=INFO
--hostname=light@%%n
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert
-Q vespa_metadata_sync,connector_deletion
stdout_logfile=/var/log/celery_worker_light.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
@@ -45,7 +45,7 @@ stopasgroup=true
command=celery -A danswer.background.celery.versioned_apps.heavy worker
--loglevel=INFO
--hostname=heavy@%%n
-Q connector_pruning,connector_doc_permissions_sync,connector_external_group_sync
-Q connector_pruning
stdout_logfile=/var/log/celery_worker_heavy.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true

View File

@@ -39,39 +39,24 @@ def test_confluence_connector_basic(
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 3
assert len(doc_batch) == 2
for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
elif ".txt" in doc.semantic_identifier:
txt_doc = doc
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
assert page_within_a_page_doc.primary_owners
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
assert len(page_within_a_page_doc.sections) == 1
page_within_a_page_section = page_within_a_page_doc.sections[0]
page_within_a_page_text = "@Chris Weaver loves cherry pie"
assert page_within_a_page_section.text == page_within_a_page_text
assert (
page_within_a_page_section.link
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
)
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
assert page_doc.primary_owners[0].email == "chris@danswer.ai"
assert len(page_doc.sections) == 1
page_section = page_doc.sections[0]
assert page_section.text == "test123 " + page_within_a_page_text
section = page_doc.sections[0]
assert section.text == "test123"
assert (
page_section.link
section.link
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)

View File

@@ -8,11 +8,11 @@ import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.server.documents.models import CCPairFullInfo
from danswer.db.enums import TaskStatus
from danswer.server.documents.models import CeleryTaskStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import DocumentSource
from danswer.server.documents.models import DocumentSyncStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
@@ -73,7 +73,7 @@ class CCPairManager:
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config,
access_type=access_type,
is_public=(access_type == AccessType.PUBLIC),
groups=groups,
user_performing_action=user_performing_action,
)
@@ -147,22 +147,7 @@ class CCPairManager:
result.raise_for_status()
@staticmethod
def get_single(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> CCPairFullInfo | None:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
cc_pair_json = response.json()
return CCPairFullInfo(**cc_pair_json)
@staticmethod
def get_indexing_status_by_id(
def get_one(
cc_pair_id: int,
user_performing_action: DATestUser | None = None,
) -> ConnectorIndexingStatus | None:
@@ -181,7 +166,7 @@ class CCPairManager:
return None
@staticmethod
def get_indexing_statuses(
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[ConnectorIndexingStatus]:
response = requests.get(
@@ -199,7 +184,7 @@ class CCPairManager:
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_cc_pairs = CCPairManager.get_indexing_statuses(user_performing_action)
all_cc_pairs = CCPairManager.get_all(user_performing_action)
for retrieved_cc_pair in all_cc_pairs:
if retrieved_cc_pair.cc_pair_id == cc_pair.id:
if verify_deleted:
@@ -249,9 +234,7 @@ class CCPairManager:
"""after: Wait for an indexing success time after this time"""
start = time.monotonic()
while True:
fetched_cc_pairs = CCPairManager.get_indexing_statuses(
user_performing_action
)
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
for fetched_cc_pair in fetched_cc_pairs:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
@@ -344,133 +327,57 @@ class CCPairManager:
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> None:
"""This function triggers a permission sync.
Naming / intent of this function probably could use improvement, but currently it's letting
409 Conflict pass through since if it's running that's what we were trying to do anyway.
"""
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
#
if result.status_code != 409:
result.raise_for_status()
result.raise_for_status()
@staticmethod
def get_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
) -> CeleryTaskStatus:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
response_str = response.json()
# If the response itself is a datetime string, parse it
if not isinstance(response_str, str):
return None
try:
return datetime.fromisoformat(response_str)
except ValueError:
return None
@staticmethod
def get_doc_sync_statuses(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> list[DocumentSyncStatus]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
doc_sync_statuses: list[DocumentSyncStatus] = []
for doc_sync_status in response.json():
doc_sync_statuses.append(
DocumentSyncStatus(
doc_id=doc_sync_status["doc_id"],
last_synced=datetime.fromisoformat(doc_sync_status["last_synced"]),
last_modified=datetime.fromisoformat(
doc_sync_status["last_modified"]
),
)
)
return doc_sync_statuses
return CeleryTaskStatus(**response.json())
@staticmethod
def wait_for_sync(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
number_of_updated_docs: int = 0,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
while True:
last_synced = CCPairManager.get_sync_task(cc_pair, user_performing_action)
if last_synced and last_synced > after:
print(f"last_synced: {last_synced}")
print(f"sync command start time: {after}")
print(f"permission sync complete: cc_pair={cc_pair.id}")
break
task = CCPairManager.get_sync_task(cc_pair, user_performing_action)
if not task:
raise ValueError("Sync task not found.")
if not task.register_time or task.register_time < after:
raise ValueError("Sync task register time is too early.")
if task.status == TaskStatus.SUCCESS:
# Sync succeeded
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"Permission sync was not completed within {timeout} seconds"
f"CC pair syncing was not completed within {timeout} seconds"
)
print(
f"Waiting for CC sync to complete. elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)
# TODO: remove this sleep,
# this shouldnt be necessary but something is off with the timing for the sync jobs
time.sleep(5)
print("waiting for vespa sync")
# wait for the vespa sync to complete once the permission sync is complete
start = time.monotonic()
while True:
doc_sync_statuses = CCPairManager.get_doc_sync_statuses(
cc_pair=cc_pair,
user_performing_action=user_performing_action,
)
synced_docs = 0
for doc_sync_status in doc_sync_statuses:
if (
doc_sync_status.last_synced is not None
and doc_sync_status.last_modified is not None
and doc_sync_status.last_synced >= doc_sync_status.last_modified
and doc_sync_status.last_synced >= after
and doc_sync_status.last_modified >= after
):
synced_docs += 1
if synced_docs >= number_of_updated_docs:
print(f"all docs synced: cc_pair={cc_pair.id}")
break
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"Vespa sync was not completed within {timeout} seconds"
)
print(
f"Waiting for vespa sync to complete. elapsed={elapsed:.2f} timeout={timeout}"
f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}"
)
time.sleep(5)
@@ -485,7 +392,7 @@ class CCPairManager:
cc_pair_id is good to do."""
start = time.monotonic()
while True:
cc_pairs = CCPairManager.get_indexing_statuses(user_performing_action)
cc_pairs = CCPairManager.get_all(user_performing_action)
if cc_pair_id:
found = False
for cc_pair in cc_pairs:

View File

@@ -4,7 +4,6 @@ from uuid import uuid4
import requests
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.server.documents.models import ConnectorUpdateRequest
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
@@ -20,7 +19,7 @@ class ConnectorManager:
source: DocumentSource = DocumentSource.FILE,
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
access_type: AccessType = AccessType.PUBLIC,
is_public: bool = True,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestConnector:
@@ -31,7 +30,7 @@ class ConnectorManager:
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config or {},
access_type=access_type,
is_public=is_public,
groups=groups or [],
)
@@ -52,7 +51,7 @@ class ConnectorManager:
input_type=input_type,
connector_specific_config=connector_specific_config or {},
groups=groups,
access_type=access_type,
is_public=is_public,
)
@staticmethod

View File

@@ -28,12 +28,10 @@ class TenantManager:
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
referral_source: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
"referral_source": referral_source,
}
token = generate_auth_token()

View File

@@ -55,7 +55,7 @@ class DATestConnector(BaseModel):
input_type: InputType
connector_specific_config: dict[str, Any]
groups: list[int] | None = None
access_type: AccessType | None = None
is_public: bool | None = None
class SimpleTestDocument(BaseModel):

View File

@@ -6,10 +6,6 @@ import pytest
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
# from tests.load_env_vars import load_env_vars
# load_env_vars()
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:

View File

@@ -67,7 +67,7 @@ def test_slack_permission_sync(
"workspace": "onyx-test-workspace",
"channels": [public_channel["name"], private_channel["name"]],
},
access_type=AccessType.SYNC,
is_public=True,
groups=[],
user_performing_action=admin_user,
)
@@ -96,13 +96,11 @@ def test_slack_permission_sync(
private_message = "Sara's favorite number is 346794"
# Add messages to channels
print(f"\n Adding public message to channel: {public_message}")
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=public_channel,
message=public_message,
)
print(f"\n Adding private message to channel: {private_message}")
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
@@ -119,6 +117,7 @@ def test_slack_permission_sync(
)
# Run permission sync
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
@@ -126,33 +125,26 @@ def test_slack_permission_sync(
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=2,
user_performing_action=admin_user,
)
# Search as admin with access to both channels
print("\nSearching as admin user")
danswer_doc_message_strings = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=admin_user,
)
print(
"\n documents retrieved by admin user: ",
danswer_doc_message_strings,
)
# Ensure admin user can see messages from both channels
assert public_message in danswer_doc_message_strings
assert private_message in danswer_doc_message_strings
# Search as test_user_2 with access to only the public channel
print("\n Searching as test_user_2")
danswer_doc_message_strings = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_2,
)
print(
"\n documents retrieved by test_user_2: ",
"\ntop_documents content before removing from private channel for test_user_2: ",
danswer_doc_message_strings,
)
@@ -161,13 +153,12 @@ def test_slack_permission_sync(
assert private_message not in danswer_doc_message_strings
# Search as test_user_1 with access to both channels
print("\n Searching as test_user_1")
danswer_doc_message_strings = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_1,
)
print(
"\n documents retrieved by test_user_1 before being removed from private channel: ",
"\ntop_documents content before removing from private channel for test_user_1: ",
danswer_doc_message_strings,
)
@@ -176,8 +167,7 @@ def test_slack_permission_sync(
assert private_message in danswer_doc_message_strings
# ----------------------MAKE THE CHANGES--------------------------
print("\n Removing test_user_1 from the private channel")
before = datetime.now(timezone.utc)
print("\nRemoving test_user_1 from the private channel")
# Remove test_user_1 from the private channel
desired_channel_members = [admin_user]
SlackManager.set_channel_members(
@@ -195,20 +185,18 @@ def test_slack_permission_sync(
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
)
# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure test_user_1 can no longer see messages from the private channel
# Search as test_user_1 with access to only the public channel
danswer_doc_message_strings = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_1,
)
print(
"\n documents retrieved by test_user_1 after being removed from private channel: ",
"\ntop_documents content after removing from private channel for test_user_1: ",
danswer_doc_message_strings,
)

View File

@@ -3,8 +3,6 @@ from datetime import datetime
from datetime import timezone
from typing import Any
import pytest
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
@@ -24,7 +22,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
@pytest.mark.xfail(reason="flaky - see DAN-986 for details", strict=False)
# @pytest.mark.xfail(reason="flaky - see DAN-835 for example", strict=False)
def test_slack_prune(
reset: None,
vespa_client: vespa_fixture,
@@ -64,7 +62,7 @@ def test_slack_prune(
"workspace": "onyx-test-workspace",
"channels": [public_channel["name"], private_channel["name"]],
},
access_type=AccessType.PUBLIC,
is_public=True,
groups=[],
user_performing_action=admin_user,
)

View File

@@ -14,12 +14,12 @@ from tests.integration.common_utils.test_models import DATestUser
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
TenantManager.create("tenant_dev1", "test1@test.com")
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
assert UserManager.verify_role(test_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
TenantManager.create("tenant_dev2", "test2@test.com")
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
assert UserManager.verify_role(test_user2, UserRole.ADMIN)

View File

@@ -11,7 +11,7 @@ from tests.integration.common_utils.test_models import DATestUser
# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
TenantManager.create("tenant_dev", "test@test.com")
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.verify_role(test_user, UserRole.ADMIN)
@@ -26,7 +26,7 @@ def test_tenant_creation(reset_multitenant: None) -> None:
test_connector = ConnectorManager.create(
name="admin_test_connector",
source=DocumentSource.FILE,
access_type=AccessType.PRIVATE,
is_public=False,
user_performing_action=test_user,
)

Some files were not shown because too many files have changed in this diff Show More