mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-21 17:55:45 +00:00
Compare commits
14 Commits
dropdown
...
cohere_def
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72730a5ba3 | ||
|
|
df8bd6daf4 | ||
|
|
6b78ab0a99 | ||
|
|
e97bf1d4e2 | ||
|
|
293dbfb8eb | ||
|
|
f4a61202a7 | ||
|
|
53f9d94ceb | ||
|
|
5058d898b8 | ||
|
|
bc7de4ec1b | ||
|
|
3ad98078f5 | ||
|
|
0fb12b42f1 | ||
|
|
158329a3cc | ||
|
|
7f1a50823b | ||
|
|
0e76bcef45 |
3
.github/workflows/pr-Integration-tests.yml
vendored
3
.github/workflows/pr-Integration-tests.yml
vendored
@@ -197,8 +197,7 @@ jobs:
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
danswer/danswer-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
/app/tests/integration/tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
|
||||
31
.github/workflows/pr-helm-chart-testing.yml
vendored
31
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -23,6 +23,21 @@ jobs:
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
@@ -37,22 +52,6 @@ jobs:
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# rkuo: I don't think we need python?
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@v5
|
||||
# with:
|
||||
# python-version: '3.11'
|
||||
# cache: 'pip'
|
||||
# cache-dependency-path: |
|
||||
# backend/requirements/default.txt
|
||||
# backend/requirements/dev.txt
|
||||
# backend/requirements/model_server.txt
|
||||
# - run: |
|
||||
# python -m pip install --upgrade pip
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
4
.vscode/launch.template.jsonc
vendored
4
.vscode/launch.template.jsonc
vendored
@@ -203,7 +203,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
@@ -232,7 +232,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
"connector_pruning",
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2",
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""default chosen assistants to none
|
||||
|
||||
Revision ID: 26b931506ecb
|
||||
Revises: 2daa494a0851
|
||||
Create Date: 2024-11-12 13:23:29.858995
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "26b931506ecb"
|
||||
down_revision = "2daa494a0851"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_new =
|
||||
CASE
|
||||
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chosen_assistants_old",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default="[-2, -1, 0]",
|
||||
),
|
||||
)
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET chosen_assistants_old =
|
||||
CASE
|
||||
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
|
||||
ELSE chosen_assistants
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
op.drop_column("user", "chosen_assistants")
|
||||
|
||||
op.alter_column(
|
||||
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
|
||||
)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add-group-sync-time
|
||||
|
||||
Revision ID: 2daa494a0851
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-11 10:57:22.991157
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2daa494a0851"
|
||||
down_revision = "c0fd6e4da83a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_time_external_group_sync",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "last_time_external_group_sync")
|
||||
@@ -1,30 +0,0 @@
|
||||
"""add creator to cc pair
|
||||
|
||||
Revision ID: 9cf5c00f72fe
|
||||
Revises: c0fd6e4da83a
|
||||
Create Date: 2024-11-12 15:16:42.682902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9cf5c00f72fe"
|
||||
down_revision = "26b931506ecb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"creator_id",
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("connector_credential_pair", "creator_id")
|
||||
@@ -16,41 +16,6 @@ class ExternalAccess:
|
||||
is_public: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"external_access": {
|
||||
"external_user_emails": list(self.external_access.external_user_emails),
|
||||
"external_user_group_ids": list(
|
||||
self.external_access.external_user_group_ids
|
||||
),
|
||||
"is_public": self.external_access.is_public,
|
||||
},
|
||||
"doc_id": self.doc_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "DocExternalAccess":
|
||||
external_access = ExternalAccess(
|
||||
external_user_emails=set(
|
||||
data["external_access"].get("external_user_emails", [])
|
||||
),
|
||||
external_user_group_ids=set(
|
||||
data["external_access"].get("external_user_group_ids", [])
|
||||
),
|
||||
is_public=data["external_access"]["is_public"],
|
||||
)
|
||||
return cls(
|
||||
external_access=external_access,
|
||||
doc_id=data["doc_id"],
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Danswer users, None indicates admin
|
||||
|
||||
@@ -15,7 +15,6 @@ class UserRole(str, Enum):
|
||||
for all groups they are a member of
|
||||
"""
|
||||
|
||||
LIMITED = "limited"
|
||||
BASIC = "basic"
|
||||
ADMIN = "admin"
|
||||
CURATOR = "curator"
|
||||
|
||||
@@ -228,17 +228,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -299,17 +294,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
@@ -662,24 +652,10 @@ async def current_user_with_expired_token(
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_limited_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
user = await double_check_user(user)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if user.role == UserRole.LIMITED:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
return user
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
@@ -735,6 +711,8 @@ def generate_state_token(
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
|
||||
|
||||
def create_danswer_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
@@ -784,22 +762,15 @@ def get_oauth_router(
|
||||
response_model=OAuth2AuthorizeResponse,
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
scopes: List[str] = Query(None),
|
||||
request: Request, scopes: List[str] = Query(None)
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
if redirect_url is not None:
|
||||
authorize_redirect_url = redirect_url
|
||||
else:
|
||||
authorize_redirect_url = str(request.url_for(callback_route_name))
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
state_data: Dict[str, str] = {
|
||||
"next_url": next_url,
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state_data: Dict[str, str] = {"next_url": next_url}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
@@ -858,11 +829,8 @@ def get_oauth_router(
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
# Authenticate user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
@@ -904,6 +872,7 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
@@ -24,8 +24,6 @@ from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
@@ -138,22 +136,6 @@ def on_task_postrun(
|
||||
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorPermissionSync.remove_from_taskset(
|
||||
int(cc_pair_id), task_id, r
|
||||
)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
|
||||
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
RedisConnectorExternalGroupSync.remove_from_taskset(
|
||||
int(cc_pair_id), task_id, r
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
|
||||
"""The first signal sent on celery worker startup"""
|
||||
|
||||
@@ -91,7 +91,5 @@ def on_setup_logging(
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
"danswer.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -92,6 +92,5 @@ celery_app.autodiscover_tasks(
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -20,8 +20,6 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_connector_stop import RedisConnectorStop
|
||||
@@ -136,10 +134,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
RedisConnectorStop.reset_all(r)
|
||||
|
||||
RedisConnectorPermissionSync.reset_all(r)
|
||||
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
@@ -239,8 +233,6 @@ celery_app.autodiscover_tasks(
|
||||
"danswer.background.celery.tasks.connector_deletion",
|
||||
"danswer.background.celery.tasks.indexing",
|
||||
"danswer.background.celery.tasks.periodic",
|
||||
"danswer.background.celery.tasks.doc_permission_syncing",
|
||||
"danswer.background.celery.tasks.external_group_syncing",
|
||||
"danswer.background.celery.tasks.pruning",
|
||||
"danswer.background.celery.tasks.shared",
|
||||
"danswer.background.celery.tasks.vespa",
|
||||
|
||||
@@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector(
|
||||
callback: RunIndexingCallbackInterface | None = None,
|
||||
) -> set[str]:
|
||||
"""
|
||||
If the SlimConnector hasnt been implemented for the given connector, just pull
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs.
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
|
||||
@@ -41,18 +41,6 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": "check_for_external_group_sync",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -143,12 +143,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
if redis_connector.permissions.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (permissions in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
redis_connector.delete.taskset_clear()
|
||||
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import doc_permission_sync_ctx
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.document import upsert_document_external_perms
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip doc permissions sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_perm_sync = cc_pair.last_time_perm_sync
|
||||
if last_perm_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
|
||||
if datetime.now(timezone.utc) >= next_sync:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_doc_permissions_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_doc_permissions_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
if redis_connector.permissions.fenced:
|
||||
return None
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
return None
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
app.send_task(
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=None,
|
||||
)
|
||||
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_permission_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(f"No doc sync func found for {source_type}")
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=datetime.now(timezone.utc),
|
||||
)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="update_external_document_permissions_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
bind=True,
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
doc_id = document_external_access.doc_id
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Then we build the update requests to update vespa
|
||||
batch_add_non_web_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
return False
|
||||
@@ -1,265 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
|
||||
|
||||
|
||||
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
|
||||
LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external group sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
last_ext_group_sync = cc_pair.last_time_external_group_sync
|
||||
if last_ext_group_sync is None:
|
||||
return True
|
||||
|
||||
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
|
||||
|
||||
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
|
||||
if datetime.now(timezone.utc) >= next_sync:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_external_group_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
continue
|
||||
|
||||
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Dont kick off a new sync if the previous one is still running
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
return None
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
_ = app.send_task(
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
# set a basic fence to start
|
||||
redis_connector.external_group_sync.set_fence(True)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_external_group_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(f"No external group sync func found for {source_type}")
|
||||
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=external_user_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
logger.info(
|
||||
f"Synced {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
raise e
|
||||
finally:
|
||||
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
|
||||
redis_connector.external_group_sync.set_fence(False)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
@@ -38,35 +38,6 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if pruning is due."""
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
if not cc_pair.last_successful_index_time:
|
||||
# if we've never indexed, we can't prune
|
||||
return False
|
||||
|
||||
# if never pruned, use the last time the connector indexed successfully
|
||||
last_pruned = cc_pair.last_successful_index_time
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
@@ -98,7 +69,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
if not _is_pruning_due(cc_pair):
|
||||
if not is_pruning_due(cc_pair, db_session, r):
|
||||
continue
|
||||
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
@@ -119,6 +90,47 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def is_pruning_due(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
) -> bool:
|
||||
"""Returns an int if pruning is triggered.
|
||||
The int represents the number of prune tasks generated (in this case, only one
|
||||
because the task is a long running generator task.)
|
||||
Returns None if no pruning is triggered (due to not being needed or
|
||||
other reasons such as simultaneous pruning restrictions.
|
||||
|
||||
Checks for scheduling related conditions, then delegates the rest of the checks to
|
||||
try_creating_prune_generator_task.
|
||||
"""
|
||||
|
||||
# skip pruning if no prune frequency is set
|
||||
# pruning can still be forced via the API which will run a pruning task directly
|
||||
if not cc_pair.connector.prune_freq:
|
||||
return False
|
||||
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# skip pruning if the next scheduled prune time hasn't been reached yet
|
||||
last_pruned = cc_pair.last_pruned
|
||||
if not last_pruned:
|
||||
if not cc_pair.last_successful_index_time:
|
||||
# if we've never indexed, we can't prune
|
||||
return False
|
||||
|
||||
# if never pruned, use the last time the connector indexed successfully
|
||||
last_pruned = cc_pair.last_successful_index_time
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
@@ -154,16 +166,10 @@ def try_creating_prune_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# skip pruning if already pruning
|
||||
if redis_connector.prune.fenced:
|
||||
if redis_connector.prune.fenced: # skip pruning if already pruning
|
||||
return None
|
||||
|
||||
# skip pruning if the cc_pair is deleting
|
||||
if redis_connector.delete.fenced:
|
||||
return None
|
||||
|
||||
# skip pruning if doc permissions sync is running
|
||||
if redis_connector.permissions.fenced:
|
||||
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
|
||||
return None
|
||||
|
||||
db_session.refresh(cc_pair)
|
||||
|
||||
@@ -59,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
|
||||
task_logger.info(f"tenant={tenant_id} doc={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -141,9 +141,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
@@ -173,8 +171,8 @@ def document_by_cc_pair_cleanup_task(
|
||||
else:
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
task_logger.info(
|
||||
f"Max retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id):
|
||||
|
||||
@@ -27,7 +27,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from danswer.db.connector import mark_ccpair_as_pruned
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
@@ -59,10 +58,6 @@ from danswer.document_index.interfaces import VespaDocumentFields
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncData,
|
||||
)
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_document_set import RedisDocumentSet
|
||||
@@ -551,47 +546,6 @@ def monitor_ccpair_pruning_taskset(
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncData | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
@@ -714,17 +668,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
n_pruning = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"pruning={n_pruning}"
|
||||
)
|
||||
|
||||
# do some cleanup before clearing fences
|
||||
@@ -738,22 +688,20 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
|
||||
)
|
||||
|
||||
for attempt in attempts:
|
||||
for a in attempts:
|
||||
# if attempts exist in the db but we don't detect them in redis, mark them as failed
|
||||
fence_key = RedisConnectorIndex.fence_key_with_ids(
|
||||
attempt.connector_credential_pair_id, attempt.search_settings_id
|
||||
a.connector_credential_pair_id, a.search_settings_id
|
||||
)
|
||||
if not r.exists(fence_key):
|
||||
failure_reason = (
|
||||
f"Unknown index attempt. Might be left over from a process restart: "
|
||||
f"index_attempt={attempt.id} "
|
||||
f"cc_pair={attempt.connector_credential_pair_id} "
|
||||
f"search_settings={attempt.search_settings_id}"
|
||||
f"index_attempt={a.id} "
|
||||
f"cc_pair={a.connector_credential_pair_id} "
|
||||
f"search_settings={a.search_settings_id}"
|
||||
)
|
||||
task_logger.warning(failure_reason)
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
|
||||
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
@@ -793,12 +741,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
@@ -869,9 +811,7 @@ def vespa_metadata_sync_task(
|
||||
)
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
|
||||
@@ -33,8 +33,8 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.logger import TaskAttemptSingleton
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -427,7 +427,7 @@ def run_indexing_entrypoint(
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
|
||||
4
backend/danswer/background/task_name_builders.py
Normal file
4
backend/danswer/background/task_name_builders.py
Normal file
@@ -0,0 +1,4 @@
|
||||
def name_sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
@@ -14,6 +14,15 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
|
||||
@@ -503,7 +503,3 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
@@ -80,10 +80,6 @@ CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
|
||||
@@ -213,17 +209,9 @@ class PostgresAdvisoryLocks(Enum):
|
||||
|
||||
|
||||
class DanswerCeleryQueues:
|
||||
# Light queue
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
|
||||
# Indexing queue
|
||||
CONNECTOR_INDEXING = "connector_indexing"
|
||||
|
||||
|
||||
@@ -233,18 +221,8 @@ class DanswerRedisLocks:
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_doc_permissions_sync_beat"
|
||||
)
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
|
||||
PRUNING_LOCK_PREFIX = "da_lock:pruning"
|
||||
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
|
||||
|
||||
|
||||
@@ -119,14 +119,3 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
|
||||
logger.error(
|
||||
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
|
||||
)
|
||||
|
||||
|
||||
# if specified, will merge the specified JSON with the existing body of the
|
||||
# request before sending it to the LLM
|
||||
LITELLM_EXTRA_BODY: dict | None = None
|
||||
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
|
||||
if _LITELLM_EXTRA_BODY_RAW:
|
||||
try:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -146,7 +146,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
|
||||
self.wiki_base, confluence_object["_links"]["webui"]
|
||||
)
|
||||
|
||||
object_text = None
|
||||
@@ -278,9 +278,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base,
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
@@ -295,9 +293,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base,
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
self.wiki_base, attachment["_links"]["webui"]
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
|
||||
@@ -100,39 +100,6 @@ def extract_text_from_confluence_html(
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ri:page"):
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_title = html_page_reference.attrs["ri:content-title"]
|
||||
if not page_title:
|
||||
continue
|
||||
|
||||
page_query = f"type=page and title='{page_title}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page_batch in confluence_client.paginated_cql_page_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page_batch[0]
|
||||
break
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client, page_contents
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -186,9 +153,7 @@ def attachment_to_content(
|
||||
return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
def build_confluence_document_id(base_url: str, content_url: str) -> str:
|
||||
"""For confluence, the document id is the page url for a page based document
|
||||
or the attachment download url for an attachment based document
|
||||
|
||||
@@ -199,8 +164,6 @@ def build_confluence_document_id(
|
||||
Returns:
|
||||
str: The document id
|
||||
"""
|
||||
if is_cloud and not base_url.endswith("/wiki"):
|
||||
base_url += "/wiki"
|
||||
return f"{base_url}{content_url}"
|
||||
|
||||
|
||||
|
||||
@@ -305,7 +305,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
query = _build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
for user_email in self._get_all_user_emails():
|
||||
logger.info(f"Fetching slim threads for user: {user_email}")
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
|
||||
@@ -192,33 +192,23 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._retrieved_ids.add(folder_id)
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
# Start with primary admin email
|
||||
user_emails = [self.primary_admin_email]
|
||||
|
||||
# Only fetch additional users if using service account
|
||||
if isinstance(self.creds, OAuthCredentials):
|
||||
return user_emails
|
||||
|
||||
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
|
||||
admin_service = get_admin_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
|
||||
# Get admins first since they're more likely to have access to most files
|
||||
for is_admin in [True, False]:
|
||||
query = "isAdmin=true" if is_admin else "isAdmin=false"
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
if email not in user_emails:
|
||||
user_emails.append(email)
|
||||
return user_emails
|
||||
query = "isAdmin=true" if admins_only else "isAdmin=false"
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
def _get_all_drive_ids(self) -> set[str]:
|
||||
primary_drive_service = get_drive_service(
|
||||
@@ -226,48 +216,55 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
all_drive_ids = set()
|
||||
# We don't want to fail if we're using OAuth because you can
|
||||
# access your my drive as a non admin user in an org still
|
||||
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=primary_drive_service.drives().list,
|
||||
list_key="drives",
|
||||
continue_on_404_or_403=ignore_fetch_failure,
|
||||
useDomainAdminAccess=True,
|
||||
fields="drives(id)",
|
||||
):
|
||||
all_drive_ids.add(drive["id"])
|
||||
|
||||
if not all_drive_ids:
|
||||
logger.warning(
|
||||
"No drives found. This is likely because oauth user "
|
||||
"is not an admin and cannot view all drive IDs. "
|
||||
"Continuing with only the shared drive IDs specified in the config."
|
||||
)
|
||||
all_drive_ids = set(self._requested_shared_drive_ids)
|
||||
|
||||
return all_drive_ids
|
||||
|
||||
def _initialize_all_class_variables(self) -> None:
|
||||
# Get all user emails
|
||||
# Get admins first becuase they are more likely to have access to the most files
|
||||
user_emails = [self.primary_admin_email]
|
||||
for admins_only in [True, False]:
|
||||
for email in self._get_all_user_emails(admins_only=admins_only):
|
||||
if email not in user_emails:
|
||||
user_emails.append(email)
|
||||
self._all_org_emails = user_emails
|
||||
|
||||
self._all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
|
||||
# remove drive ids from the folder ids because they are queried differently
|
||||
self._requested_folder_ids -= self._all_drive_ids
|
||||
|
||||
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
|
||||
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
|
||||
if invalid_drive_ids:
|
||||
logger.warning(
|
||||
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
|
||||
)
|
||||
logger.warning("Checking for folder access instead...")
|
||||
self._requested_folder_ids.update(invalid_drive_ids)
|
||||
|
||||
if not self.include_shared_drives:
|
||||
self._requested_shared_drive_ids = set()
|
||||
elif not self._requested_shared_drive_ids:
|
||||
self._requested_shared_drive_ids = self._all_drive_ids
|
||||
|
||||
def _impersonate_user_for_retrieval(
|
||||
self,
|
||||
user_email: str,
|
||||
is_slim: bool,
|
||||
filtered_drive_ids: set[str],
|
||||
filtered_folder_ids: set[str],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
|
||||
# if we are including my drives, try to get the current user's my
|
||||
# drive if any of the following are true:
|
||||
# - no specific emails were requested
|
||||
# - the current user's email is in the requested emails
|
||||
# - we are using OAuth (in which case we assume that is the only email we will try)
|
||||
if self.include_my_drives and (
|
||||
not self._requested_my_drive_emails
|
||||
or user_email in self._requested_my_drive_emails
|
||||
or isinstance(self.creds, OAuthCredentials)
|
||||
):
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
@@ -277,7 +274,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end=end,
|
||||
)
|
||||
|
||||
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
|
||||
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
@@ -288,7 +285,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end=end,
|
||||
)
|
||||
|
||||
remaining_folders = filtered_folder_ids - self._retrieved_ids
|
||||
remaining_folders = self._requested_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
@@ -305,56 +302,22 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
all_org_emails: list[str] = self._get_all_user_emails()
|
||||
|
||||
all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
|
||||
# remove drive ids from the folder ids because they are queried differently
|
||||
filtered_folder_ids = self._requested_folder_ids - all_drive_ids
|
||||
|
||||
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
|
||||
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
|
||||
if invalid_drive_ids:
|
||||
logger.warning(
|
||||
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
|
||||
)
|
||||
logger.warning("Checking for folder access instead...")
|
||||
filtered_folder_ids.update(invalid_drive_ids)
|
||||
|
||||
# If including shared drives, use the requested IDs if provided,
|
||||
# otherwise use all drive IDs
|
||||
filtered_drive_ids = set()
|
||||
if self.include_shared_drives:
|
||||
if self._requested_shared_drive_ids:
|
||||
# Remove invalid drive IDs from requested IDs
|
||||
filtered_drive_ids = (
|
||||
self._requested_shared_drive_ids - invalid_drive_ids
|
||||
)
|
||||
else:
|
||||
filtered_drive_ids = all_drive_ids
|
||||
self._initialize_all_class_variables()
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
executor.submit(
|
||||
self._impersonate_user_for_retrieval,
|
||||
email,
|
||||
is_slim,
|
||||
filtered_drive_ids,
|
||||
filtered_folder_ids,
|
||||
start,
|
||||
end,
|
||||
self._impersonate_user_for_retrieval, email, is_slim, start, end
|
||||
): email
|
||||
for email in all_org_emails
|
||||
for email in self._all_org_emails
|
||||
}
|
||||
|
||||
# Yield results as they complete
|
||||
for future in as_completed(future_to_email):
|
||||
yield from future.result()
|
||||
|
||||
remaining_folders = (
|
||||
filtered_drive_ids | filtered_folder_ids
|
||||
) - self._retrieved_ids
|
||||
remaining_folders = self._requested_folder_ids - self._retrieved_ids
|
||||
if remaining_folders:
|
||||
logger.warning(
|
||||
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
|
||||
|
||||
@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
|
||||
)()
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
logger.warning(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -17,8 +17,6 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
@@ -87,9 +85,7 @@ logger = setup_logger()
|
||||
|
||||
# Prometheus metric for HPA
|
||||
active_tenants_gauge = Gauge(
|
||||
"active_tenants",
|
||||
"Number of active tenants handled by this pod",
|
||||
["namespace", "pod"],
|
||||
"active_tenants", "Number of active tenants handled by this pod"
|
||||
)
|
||||
|
||||
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
|
||||
@@ -152,9 +148,7 @@ class SlackbotHandler:
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
self.acquire_tenants()
|
||||
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
|
||||
len(self.tenant_ids)
|
||||
)
|
||||
active_tenants_gauge.set(len(self.tenant_ids))
|
||||
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in Slack acquisition: {e}")
|
||||
|
||||
@@ -282,32 +282,3 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
|
||||
|
||||
cc_pair.last_pruned = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_cc_pair_as_permissions_synced(
|
||||
db_session: Session, cc_pair_id: int, start_time: datetime | None
|
||||
) -> None:
|
||||
stmt = select(ConnectorCredentialPair).where(
|
||||
ConnectorCredentialPair.id == cc_pair_id
|
||||
)
|
||||
cc_pair = db_session.scalar(stmt)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
|
||||
|
||||
cc_pair.last_time_perm_sync = start_time
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
|
||||
stmt = select(ConnectorCredentialPair).where(
|
||||
ConnectorCredentialPair.id == cc_pair_id
|
||||
)
|
||||
cc_pair = db_session.scalar(stmt)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
|
||||
|
||||
# The sync time can be marked after it ran because all group syncs
|
||||
# are run in full, not polling for changes.
|
||||
# If this changes, we need to update this function.
|
||||
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
@@ -76,10 +76,8 @@ def _add_user_filters(
|
||||
.where(~UG__CCpair.user_group_id.in_(user_groups))
|
||||
.correlate(ConnectorCredentialPair)
|
||||
)
|
||||
where_clause |= ConnectorCredentialPair.creator_id == user.id
|
||||
else:
|
||||
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
|
||||
where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -389,7 +387,6 @@ def add_credential_to_connector(
|
||||
)
|
||||
|
||||
association = ConnectorCredentialPair(
|
||||
creator_id=user.id if user else None,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
|
||||
@@ -19,7 +19,6 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import null
|
||||
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
@@ -47,21 +46,13 @@ def count_documents_by_needs_sync(session: Session) -> int:
|
||||
"""Get the count of all documents where:
|
||||
1. last_modified is newer than last_synced
|
||||
2. last_synced is null (meaning we've never synced)
|
||||
AND the document has a relationship with a connector/credential pair
|
||||
|
||||
TODO: The documents without a relationship with a connector/credential pair
|
||||
should be cleaned up somehow eventually.
|
||||
|
||||
This function executes the query and returns the count of
|
||||
documents matching the criteria."""
|
||||
|
||||
count = (
|
||||
session.query(func.count(DbDocument.id.distinct()))
|
||||
session.query(func.count())
|
||||
.select_from(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
or_(
|
||||
DbDocument.last_modified > DbDocument.last_synced,
|
||||
@@ -100,22 +91,6 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
return stmt
|
||||
|
||||
|
||||
def get_all_documents_needing_vespa_sync_for_cc_pair(
|
||||
db_session: Session, cc_pair_id: int
|
||||
) -> list[DbDocument]:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair(
|
||||
connector_id: int, credential_id: int | None = None
|
||||
) -> Select:
|
||||
@@ -129,21 +104,6 @@ def construct_document_select_for_connector_credential_pair(
|
||||
return stmt
|
||||
|
||||
|
||||
def get_documents_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> list[DbDocument]:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_document_ids_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> list[str]:
|
||||
@@ -308,7 +268,7 @@ def get_access_info_for_documents(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def _upsert_documents(
|
||||
def upsert_documents(
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
initial_boost: int = DEFAULT_BOOST,
|
||||
@@ -346,8 +306,6 @@ def _upsert_documents(
|
||||
]
|
||||
)
|
||||
|
||||
# This does not update the permissions of the document if
|
||||
# the document already exists.
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
@@ -364,7 +322,7 @@ def _upsert_documents(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _upsert_document_by_connector_credential_pair(
|
||||
def upsert_document_by_connector_credential_pair(
|
||||
db_session: Session, document_metadata_batch: list[DocumentMetadata]
|
||||
) -> None:
|
||||
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
|
||||
@@ -446,8 +404,8 @@ def upsert_documents_complete(
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
) -> None:
|
||||
_upsert_documents(db_session, document_metadata_batch)
|
||||
_upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
|
||||
upsert_documents(db_session, document_metadata_batch)
|
||||
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
|
||||
logger.info(
|
||||
f"Upserted {len(document_metadata_batch)} document store entries into DB"
|
||||
)
|
||||
@@ -505,6 +463,7 @@ def delete_documents_complete__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
logger.info(f"Deleting {len(document_ids)} documents from the DB")
|
||||
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_document_feedback_for_documents__no_commit(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
|
||||
@@ -189,13 +189,6 @@ class SqlEngine:
|
||||
return ""
|
||||
return cls._app_name
|
||||
|
||||
@classmethod
|
||||
def reset_engine(cls) -> None:
|
||||
with cls._lock:
|
||||
if cls._engine:
|
||||
cls._engine.dispose()
|
||||
cls._engine = None
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
@@ -319,9 +312,7 @@ async def get_async_session_with_tenant(
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
@@ -382,9 +373,7 @@ def get_session_with_tenant(
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -126,8 +126,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
chosen_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
|
||||
)
|
||||
visible_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
@@ -173,11 +173,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
)
|
||||
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
|
||||
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
back_populates="creator",
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt(Base):
|
||||
@@ -425,9 +420,6 @@ class ConnectorCredentialPair(Base):
|
||||
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# Time finished, not used for calculating backend jobs which uses time started (created)
|
||||
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
@@ -460,14 +452,6 @@ class ConnectorCredentialPair(Base):
|
||||
"IndexAttempt", back_populates="connector_credential_pair"
|
||||
)
|
||||
|
||||
# the user id of the user that created this cc pair
|
||||
creator_id: Mapped[UUID | None] = mapped_column(nullable=True)
|
||||
creator: Mapped["User"] = relationship(
|
||||
"User",
|
||||
back_populates="cc_pairs",
|
||||
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
|
||||
@@ -743,4 +743,5 @@ def delete_persona_by_name(
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -97,18 +97,3 @@ def batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session.flush() # generate ids
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
|
||||
def batch_add_non_web_user_if_not_exists(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> list[User]:
|
||||
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
|
||||
|
||||
new_users: list[User] = []
|
||||
for email in missing_user_emails:
|
||||
new_users.append(_generate_non_web_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.commit()
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
@@ -56,7 +56,7 @@ class IndexingPipelineProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
def _upsert_documents_in_db(
|
||||
def upsert_documents_in_db(
|
||||
documents: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
@@ -243,7 +243,7 @@ def index_doc_batch_prepare(
|
||||
|
||||
# Create records in the source of truth about these documents,
|
||||
# does not include doc_updated_at which is also used to indicate a successful update
|
||||
_upsert_documents_in_db(
|
||||
upsert_documents_in_db(
|
||||
documents=documents,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
@@ -255,7 +255,7 @@ def index_doc_batch_prepare(
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
@log_function_time()
|
||||
def index_doc_batch(
|
||||
*,
|
||||
chunker: Chunker,
|
||||
|
||||
@@ -26,7 +26,6 @@ from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
|
||||
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
@@ -214,7 +213,6 @@ class DefaultMultiLLM(LLM):
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict | None = LITELLM_EXTRA_BODY,
|
||||
):
|
||||
self._timeout = timeout
|
||||
self._model_provider = model_provider
|
||||
@@ -248,8 +246,6 @@ class DefaultMultiLLM(LLM):
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
if extra_headers:
|
||||
model_kwargs.update({"extra_headers": extra_headers})
|
||||
if extra_body:
|
||||
model_kwargs.update({"extra_body": extra_body})
|
||||
|
||||
self._model_kwargs = model_kwargs
|
||||
|
||||
|
||||
@@ -315,7 +315,7 @@ def get_application() -> FastAPI:
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import redis
|
||||
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from danswer.redis.redis_connector_stop import RedisConnectorStop
|
||||
@@ -21,10 +19,6 @@ class RedisConnector:
|
||||
self.stop = RedisConnectorStop(tenant_id, id, self.redis)
|
||||
self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
|
||||
self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
|
||||
self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
|
||||
self.external_group_sync = RedisConnectorExternalGroupSync(
|
||||
tenant_id, id, self.redis
|
||||
)
|
||||
|
||||
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
|
||||
return RedisConnectorIndex(
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncData(BaseModel):
|
||||
started: datetime | None
|
||||
|
||||
|
||||
class RedisConnectorPermissionSync:
|
||||
"""Manages interactions with redis for doc permission sync tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectordocpermissionsync"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
|
||||
# phase 1 - geneartor task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpermissions+generator
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # connectorpermissions_generator_progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorpermissions_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
|
||||
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
|
||||
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
|
||||
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_remaining(self) -> int:
|
||||
remaining = cast(int, self.redis.scard(self.taskset_key))
|
||||
return remaining
|
||||
|
||||
def get_active_task_count(self) -> int:
|
||||
"""Count of active permission sync tasks"""
|
||||
count = 0
|
||||
for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def payload(self) -> RedisConnectorPermissionSyncData | None:
|
||||
# read related data and evaluate/print task progress
|
||||
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_str = fence_bytes.decode("utf-8")
|
||||
payload = RedisConnectorPermissionSyncData.model_validate_json(
|
||||
cast(str, fence_str)
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
def set_fence(
|
||||
self,
|
||||
payload: RedisConnectorPermissionSyncData | None,
|
||||
) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
"""the fence payload is an int representing the starting number of
|
||||
permission sync tasks to be processed ... just after the generator completes."""
|
||||
fence_bytes = self.redis.get(self.generator_complete_key)
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
if fence_bytes == b"None":
|
||||
return None
|
||||
|
||||
fence_int = int(cast(bytes, fence_bytes).decode())
|
||||
return fence_int
|
||||
|
||||
@generator_complete.setter
|
||||
def generator_complete(self, payload: int | None) -> None:
|
||||
"""Set the payload to an int to set the fence, otherwise if None it will
|
||||
be deleted"""
|
||||
if payload is None:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
lock: redis.lock.Lock | None,
|
||||
new_permissions: list[DocExternalAccess],
|
||||
source_string: str,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
async_results = []
|
||||
|
||||
# Create a task for each document permission sync
|
||||
for doc_perm in new_permissions:
|
||||
current_time = time.monotonic()
|
||||
if lock and current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
# Add task for document permissions sync
|
||||
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
"update_external_document_permissions_task",
|
||||
kwargs=dict(
|
||||
tenant_id=self.tenant_id,
|
||||
serialized_doc_external_access=doc_perm.to_dict(),
|
||||
source_string=source_string,
|
||||
),
|
||||
queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
|
||||
r.srem(taskset_key, task_id)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(
|
||||
RedisConnectorPermissionSync.GENERATOR_COMPLETE_PREFIX + "*"
|
||||
):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(
|
||||
RedisConnectorPermissionSync.GENERATOR_PROGRESS_PREFIX + "*"
|
||||
):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
@@ -1,133 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class RedisConnectorExternalGroupSync:
|
||||
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectorexternalgroupsync"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
|
||||
# phase 1 - geneartor task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorexternalgroupsync+generator
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # connectorexternalgroupsync_generator_progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorexternalgroupsync_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
|
||||
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
|
||||
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
|
||||
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_remaining(self) -> int:
|
||||
# todo: move into fence
|
||||
remaining = cast(int, self.redis.scard(self.taskset_key))
|
||||
return remaining
|
||||
|
||||
def get_active_task_count(self) -> int:
|
||||
"""Count of active external group syncing tasks"""
|
||||
count = 0
|
||||
for _ in self.redis.scan_iter(
|
||||
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_fence(self, value: bool) -> None:
|
||||
if not value:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
"""the fence payload is an int representing the starting number of
|
||||
external group syncing tasks to be processed ... just after the generator completes.
|
||||
"""
|
||||
fence_bytes = self.redis.get(self.generator_complete_key)
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
if fence_bytes == b"None":
|
||||
return None
|
||||
|
||||
fence_int = int(cast(bytes, fence_bytes).decode())
|
||||
return fence_int
|
||||
|
||||
@generator_complete.setter
|
||||
def generator_complete(self, payload: int | None) -> None:
|
||||
"""Set the payload to an int to set the fence, otherwise if None it will
|
||||
be deleted"""
|
||||
if payload is None:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
lock: redis.lock.Lock | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"
|
||||
r.srem(taskset_key, task_id)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.GENERATOR_COMPLETE_PREFIX + "*"
|
||||
):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.GENERATOR_PROGRESS_PREFIX + "*"
|
||||
):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from cohere import Client
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import default_public_access
|
||||
@@ -32,7 +33,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.retry_wrapper import retry_builder
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -91,18 +92,6 @@ def _create_indexable_chunks(
|
||||
return list(ids_to_documents.values()), chunks
|
||||
|
||||
|
||||
# Cohere is used in EE version
|
||||
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
initial_docs_path = os.path.join(
|
||||
os.getcwd(),
|
||||
"danswer",
|
||||
"seeding",
|
||||
"initial_docs.json",
|
||||
)
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
return processed_docs
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
@@ -188,10 +177,32 @@ def seed_initial_documents(
|
||||
last_successful_index_time=last_index_time,
|
||||
)
|
||||
cc_pair_id = cast(int, result.data)
|
||||
processed_docs = fetch_versioned_implementation(
|
||||
"danswer.seeding.load_docs",
|
||||
"load_processed_docs",
|
||||
)(cohere_enabled)
|
||||
|
||||
if cohere_enabled:
|
||||
initial_docs_path = os.path.join(
|
||||
os.getcwd(), "danswer", "seeding", "initial_docs_cohere.json"
|
||||
)
|
||||
|
||||
cohere_client = Client(COHERE_DEFAULT_API_KEY)
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
for doc in processed_docs:
|
||||
title_embedding = cohere_client.embed(
|
||||
texts=[doc["title"]], model="embed-english-v3.0"
|
||||
).embeddings[0]
|
||||
content_embedding = cohere_client.embed(
|
||||
texts=[doc["content"]], model="embed-english-v3.0"
|
||||
).embeddings[0]
|
||||
doc["title_embedding"] = title_embedding
|
||||
doc["content_embedding"] = content_embedding
|
||||
|
||||
else:
|
||||
initial_docs_path = os.path.join(
|
||||
os.getcwd(),
|
||||
"danswer",
|
||||
"seeding",
|
||||
"initial_docs.json",
|
||||
)
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from starlette.routing import BaseRoute
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import current_user_with_expired_token
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
@@ -103,8 +102,7 @@ def check_router_auth(
|
||||
for dependency in route_dependant_obj.dependencies:
|
||||
depends_fn = dependency.cache_key[0]
|
||||
if (
|
||||
depends_fn == current_limited_user
|
||||
or depends_fn == current_user
|
||||
depends_fn == current_user
|
||||
or depends_fn == current_admin_user
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
@@ -120,5 +118,5 @@ def check_router_auth(
|
||||
# print(f"(\"{route.path}\", {set(route.methods)}),")
|
||||
|
||||
raise RuntimeError(
|
||||
f"Did not find user dependency in private route - {route}"
|
||||
f"Did not find current_user or current_admin_user dependency in route - {route}"
|
||||
)
|
||||
|
||||
@@ -12,13 +12,13 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
try_creating_permissions_sync_task,
|
||||
)
|
||||
from danswer.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
||||
@@ -26,7 +26,6 @@ from danswer.db.connector_credential_pair import (
|
||||
update_connector_credential_pair_from_id,
|
||||
)
|
||||
from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.document import get_documents_for_cc_pair
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
@@ -39,13 +38,15 @@ from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.server.documents.models import CCStatusUpdateRequest
|
||||
from danswer.server.documents.models import CeleryTaskStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||
from danswer.server.documents.models import DocumentSyncStatus
|
||||
from danswer.server.documents.models import PaginatedIndexAttempts
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -287,12 +288,12 @@ def prune_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
|
||||
def get_cc_pair_latest_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> datetime | None:
|
||||
) -> CeleryTaskStatus:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
@@ -302,20 +303,34 @@ def get_cc_pair_latest_sync(
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_time_perm_sync
|
||||
# look up the last sync task for this connector (if it exists)
|
||||
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||
if not last_sync_task:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="No sync task found.",
|
||||
)
|
||||
|
||||
return CeleryTaskStatus(
|
||||
id=last_sync_task.task_id,
|
||||
name=last_sync_task.task_name,
|
||||
status=last_sync_task.status,
|
||||
start_time=last_sync_task.start_time,
|
||||
register_time=last_sync_task.register_time,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
|
||||
def sync_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers permissions sync on a particular cc_pair immediately"""
|
||||
# avoiding circular refs
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -329,49 +344,37 @@ def sync_cc_pair(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.permissions.fenced:
|
||||
if last_sync_task and check_task_is_live_and_not_timed_out(
|
||||
last_sync_task, db_session
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Doc permissions sync task already in progress.",
|
||||
detail="Sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Doc permissions sync cc_pair={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
|
||||
sync_external_doc_permissions_task = fetch_ee_implementation_or_noop(
|
||||
"danswer.background.celery.apps.primary",
|
||||
"sync_external_doc_permissions_task",
|
||||
None,
|
||||
)
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not tasks_created:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="Doc permissions sync task creation failed.",
|
||||
|
||||
if sync_external_doc_permissions_task:
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
),
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the doc permissions sync task.",
|
||||
message="Successfully created the sync task.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
|
||||
def get_docs_sync_status(
|
||||
cc_pair_id: int,
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[DocumentSyncStatus]:
|
||||
all_docs_for_cc_pair = get_documents_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair]
|
||||
|
||||
|
||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||
def associate_credential_to_connector(
|
||||
connector_id: int,
|
||||
@@ -387,7 +390,6 @@ def associate_credential_to_connector(
|
||||
user=user,
|
||||
target_group_ids=metadata.groups,
|
||||
object_is_public=metadata.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=metadata.access_type == AccessType.SYNC,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -81,6 +81,7 @@ from danswer.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
@@ -664,8 +665,7 @@ def create_connector_from_model(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
|
||||
object_is_public=connector_data.is_public,
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
return create_connector(
|
||||
@@ -683,31 +683,32 @@ def create_connector_with_mock_credential(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
|
||||
)
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
if connector_data.is_public:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="User does not have permission to create public credentials",
|
||||
)
|
||||
if not connector_data.groups:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Curators must specify 1+ groups",
|
||||
)
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
connector_response = create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_data,
|
||||
db_session=db_session, connector_data=connector_data
|
||||
)
|
||||
|
||||
mock_credential = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=connector_data.source,
|
||||
credential_json={}, admin_public=True, source=connector_data.source
|
||||
)
|
||||
credential = create_credential(
|
||||
credential_data=mock_credential,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
mock_credential, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
access_type = (
|
||||
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
||||
)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
@@ -715,7 +716,7 @@ def create_connector_with_mock_credential(
|
||||
user=user,
|
||||
connector_id=cast(int, connector_response.id), # will aways be an int
|
||||
credential_id=credential.id,
|
||||
access_type=connector_data.access_type,
|
||||
access_type=access_type,
|
||||
cc_pair_name=connector_data.name,
|
||||
groups=connector_data.groups,
|
||||
)
|
||||
@@ -740,8 +741,7 @@ def update_connector_from_model(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
|
||||
object_is_public=connector_data.is_public,
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
except ValueError as e:
|
||||
|
||||
@@ -14,7 +14,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexAttemptError as DbIndexAttemptError
|
||||
from danswer.db.models import IndexingStatus
|
||||
@@ -22,20 +21,6 @@ from danswer.db.models import TaskStatus
|
||||
from danswer.server.utils import mask_credential_dict
|
||||
|
||||
|
||||
class DocumentSyncStatus(BaseModel):
|
||||
doc_id: str
|
||||
last_synced: datetime | None
|
||||
last_modified: datetime | None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, doc: DbDocument) -> "DocumentSyncStatus":
|
||||
return DocumentSyncStatus(
|
||||
doc_id=doc.id,
|
||||
last_synced=doc.last_synced,
|
||||
last_modified=doc.last_modified,
|
||||
)
|
||||
|
||||
|
||||
class DocumentInfo(BaseModel):
|
||||
num_chunks: int
|
||||
num_tokens: int
|
||||
@@ -64,11 +49,11 @@ class ConnectorBase(BaseModel):
|
||||
|
||||
|
||||
class ConnectorUpdateRequest(ConnectorBase):
|
||||
access_type: AccessType
|
||||
is_public: bool = True
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
def to_connector_base(self) -> ConnectorBase:
|
||||
return ConnectorBase(**self.model_dump(exclude={"access_type", "groups"}))
|
||||
return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"}))
|
||||
|
||||
|
||||
class ConnectorSnapshot(ConnectorBase):
|
||||
@@ -237,8 +222,6 @@ class CCPairFullInfo(BaseModel):
|
||||
is_editable_for_current_user: bool
|
||||
deletion_failure_message: str | None
|
||||
indexing: bool
|
||||
creator: UUID | None
|
||||
creator_email: str | None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
@@ -284,10 +267,6 @@ class CCPairFullInfo(BaseModel):
|
||||
is_editable_for_current_user=is_editable_for_current_user,
|
||||
deletion_failure_message=cc_pair_model.deletion_failure_message,
|
||||
indexing=indexing,
|
||||
creator=cc_pair_model.creator_id,
|
||||
creator_email=cc_pair_model.creator.email
|
||||
if cc_pair_model.creator
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import NotificationType
|
||||
@@ -273,7 +272,7 @@ def list_personas(
|
||||
@basic_router.get("/{persona_id}")
|
||||
def get_persona(
|
||||
persona_id: int,
|
||||
user: User | None = Depends(current_limited_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
return PersonaSnapshot.from_model(
|
||||
|
||||
@@ -630,25 +630,31 @@ def update_user_assistant_list(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_assistant_visibility(
|
||||
def update_assistant_list(
|
||||
preferences: UserPreferences, assistant_id: int, show: bool
|
||||
) -> UserPreferences:
|
||||
visible_assistants = preferences.visible_assistants or []
|
||||
hidden_assistants = preferences.hidden_assistants or []
|
||||
chosen_assistants = preferences.chosen_assistants or []
|
||||
|
||||
if show:
|
||||
if assistant_id not in visible_assistants:
|
||||
visible_assistants.append(assistant_id)
|
||||
if assistant_id in hidden_assistants:
|
||||
hidden_assistants.remove(assistant_id)
|
||||
if assistant_id not in chosen_assistants:
|
||||
chosen_assistants.append(assistant_id)
|
||||
else:
|
||||
if assistant_id in visible_assistants:
|
||||
visible_assistants.remove(assistant_id)
|
||||
if assistant_id not in hidden_assistants:
|
||||
hidden_assistants.append(assistant_id)
|
||||
if assistant_id in chosen_assistants:
|
||||
chosen_assistants.remove(assistant_id)
|
||||
|
||||
preferences.visible_assistants = visible_assistants
|
||||
preferences.hidden_assistants = hidden_assistants
|
||||
preferences.chosen_assistants = chosen_assistants
|
||||
return preferences
|
||||
|
||||
|
||||
@@ -664,23 +670,15 @@ def update_user_assistant_visibility(
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
preferences = no_auth_user.preferences
|
||||
updated_preferences = update_assistant_visibility(
|
||||
preferences, assistant_id, show
|
||||
)
|
||||
if updated_preferences.chosen_assistants is not None:
|
||||
updated_preferences.chosen_assistants.append(assistant_id)
|
||||
|
||||
updated_preferences = update_assistant_list(preferences, assistant_id, show)
|
||||
set_no_auth_user_preferences(store, updated_preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
user_preferences = UserInfo.from_model(user).preferences
|
||||
updated_preferences = update_assistant_visibility(
|
||||
user_preferences, assistant_id, show
|
||||
)
|
||||
if updated_preferences.chosen_assistants is not None:
|
||||
updated_preferences.chosen_assistants.append(assistant_id)
|
||||
updated_preferences = update_assistant_list(user_preferences, assistant_id, show)
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
|
||||
@@ -18,7 +18,6 @@ from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import extract_headers
|
||||
@@ -310,7 +309,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
user: User | None = Depends(current_limited_user),
|
||||
user: User | None = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
is_connected_func: Callable[[], bool] = Depends(is_connected),
|
||||
) -> StreamingResponse:
|
||||
@@ -392,7 +391,7 @@ def set_message_as_latest(
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
user: User | None = Depends(current_limited_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user else None
|
||||
|
||||
@@ -9,7 +9,6 @@ from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
@@ -263,7 +262,7 @@ def stream_query_validation(
|
||||
@basic_router.post("/stream-answer-with-quote")
|
||||
def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User = Depends(current_limited_user),
|
||||
user: User = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
query = query_request.messages[0].message
|
||||
|
||||
@@ -21,12 +21,8 @@ pruning_ctx: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
|
||||
"pruning_ctx", default=dict()
|
||||
)
|
||||
|
||||
doc_permission_sync_ctx: contextvars.ContextVar[
|
||||
dict[str, Any]
|
||||
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
|
||||
|
||||
|
||||
class TaskAttemptSingleton:
|
||||
class IndexAttemptSingleton:
|
||||
"""Used to tell if this process is an indexing job, and if so what is the
|
||||
unique identifier for this indexing attempt. For things like the API server,
|
||||
main background job (scheduler), etc. this will not be used."""
|
||||
@@ -70,10 +66,9 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
# If this is an indexing job, add the attempt ID to the log message
|
||||
# This helps filter the logs for this specific indexing
|
||||
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
|
||||
index_attempt_id = IndexAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = IndexAttemptSingleton.get_connector_credential_pair_id()
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
if len(pruning_ctx_dict) > 0:
|
||||
if "request_id" in pruning_ctx_dict:
|
||||
@@ -81,9 +76,6 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
|
||||
|
||||
if "cc_pair_id" in pruning_ctx_dict:
|
||||
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
|
||||
elif len(doc_permission_sync_ctx_dict) > 0:
|
||||
if "request_id" in doc_permission_sync_ctx_dict:
|
||||
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
|
||||
else:
|
||||
if index_attempt_id is not None:
|
||||
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
|
||||
|
||||
@@ -1,12 +1,32 @@
|
||||
from danswer.background.celery.apps.primary import celery_app
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.db.chat import delete_chat_sessions_older_than
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_doc_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_group_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_doc_permission_sync,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_group_permission_sync,
|
||||
)
|
||||
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -14,6 +34,25 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
global_version.set_ee()
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_group_permissions_task(
|
||||
cc_pair_id: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@@ -28,6 +67,38 @@ def perform_ttl_management_task(
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_sync_external_doc_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external permissions"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_doc_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_sync_external_group_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external group permissions"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_group_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_group_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
|
||||
@@ -6,6 +6,16 @@ from danswer.background.celery.tasks.beat_schedule import (
|
||||
)
|
||||
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "sync-external-doc-permissions",
|
||||
"task": "check_sync_external_doc_permissions_task",
|
||||
"schedule": timedelta(seconds=30), # TODO: optimize this
|
||||
},
|
||||
{
|
||||
"name": "sync-external-group-permissions",
|
||||
"task": "check_sync_external_group_permissions_task",
|
||||
"schedule": timedelta(seconds=60), # TODO: optimize this
|
||||
},
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": "autogenerate_usage_report_task",
|
||||
|
||||
@@ -1,13 +1,46 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
if cc_pair.last_time_perm_sync is None:
|
||||
return True
|
||||
|
||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def should_perform_chat_ttl_check(
|
||||
retention_limit_days: int | None, db_session: Session
|
||||
) -> bool:
|
||||
@@ -24,3 +57,47 @@ def should_perform_chat_ttl_check(
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def should_perform_external_doc_permissions_check(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
|
||||
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_perform_external_group_permissions_check(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
|
||||
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,2 +1,8 @@
|
||||
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
def name_sync_external_group_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_group_permissions_task__{cc_pair_id}"
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -48,53 +45,3 @@ def upsert_document_external_perms__no_commit(
|
||||
document.external_user_emails = list(external_access.external_user_emails)
|
||||
document.external_user_group_ids = prefixed_external_groups
|
||||
document.is_public = external_access.is_public
|
||||
|
||||
|
||||
def upsert_document_external_perms(
|
||||
db_session: Session,
|
||||
doc_id: str,
|
||||
external_access: ExternalAccess,
|
||||
source_type: DocumentSource,
|
||||
) -> None:
|
||||
"""
|
||||
This sets the permissions for a document in postgres.
|
||||
NOTE: this will replace any existing external access, it will not do a union
|
||||
"""
|
||||
document = db_session.scalars(
|
||||
select(DbDocument).where(DbDocument.id == doc_id)
|
||||
).first()
|
||||
|
||||
prefixed_external_groups: set[str] = {
|
||||
prefix_group_w_source(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
}
|
||||
|
||||
if not document:
|
||||
# If the document does not exist, still store the external access
|
||||
# So that if the document is added later, the external access is already stored
|
||||
# The upsert function in the indexing pipeline does not overwrite the permissions fields
|
||||
document = DbDocument(
|
||||
id=doc_id,
|
||||
semantic_id="",
|
||||
external_user_emails=external_access.external_user_emails,
|
||||
external_user_group_ids=prefixed_external_groups,
|
||||
is_public=external_access.is_public,
|
||||
)
|
||||
db_session.add(document)
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
# If the document exists, we need to check if the external access has changed
|
||||
if (
|
||||
external_access.external_user_emails != set(document.external_user_emails or [])
|
||||
or prefixed_external_groups != set(document.external_user_group_ids or [])
|
||||
or external_access.is_public != document.is_public
|
||||
):
|
||||
document.external_user_emails = list(external_access.external_user_emails)
|
||||
document.external_user_group_ids = list(prefixed_external_groups)
|
||||
document.is_public = external_access.is_public
|
||||
document.last_modified = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
@@ -9,12 +9,11 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.utils import prefix_group_w_source
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import User__ExternalUserGroupId
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
|
||||
|
||||
class ExternalUserGroup(BaseModel):
|
||||
id: str
|
||||
user_emails: list[str]
|
||||
user_ids: list[UUID]
|
||||
|
||||
|
||||
def delete_user__ext_group_for_user__no_commit(
|
||||
@@ -39,7 +38,7 @@ def delete_user__ext_group_for_cc_pair__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def replace_user__ext_group_for_cc_pair(
|
||||
def replace_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
group_defs: list[ExternalUserGroup],
|
||||
@@ -47,44 +46,24 @@ def replace_user__ext_group_for_cc_pair(
|
||||
) -> None:
|
||||
"""
|
||||
This function clears all existing external user group relations for a given cc_pair_id
|
||||
and replaces them with the new group definitions and commits the changes.
|
||||
and replaces them with the new group definitions.
|
||||
"""
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# collect all emails from all groups to batch add all users at once for efficiency
|
||||
all_group_member_emails = set()
|
||||
for external_group in group_defs:
|
||||
for user_email in external_group.user_emails:
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
# batch add users if they don't exist and get their ids
|
||||
all_group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=list(all_group_member_emails)
|
||||
)
|
||||
|
||||
# map emails to ids
|
||||
email_id_map = {user.email: user.id for user in all_group_members}
|
||||
|
||||
# use these ids to create new external user group relations relating group_id to user_ids
|
||||
new_external_permissions = []
|
||||
for external_group in group_defs:
|
||||
for user_email in external_group.user_emails:
|
||||
user_id = email_id_map[user_email]
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=prefix_group_w_source(
|
||||
external_group.id, source
|
||||
),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
)
|
||||
new_external_permissions = [
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=prefix_group_w_source(external_group.id, source),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
for external_group in group_defs
|
||||
for user_id in external_group.user_ids
|
||||
]
|
||||
|
||||
db_session.add_all(new_external_permissions)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_external_groups_for_user(
|
||||
|
||||
@@ -124,21 +124,16 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
def validate_user_creation_permissions(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
object_is_public: bool | None = None,
|
||||
object_is_perm_sync: bool | None = None,
|
||||
target_group_ids: list[int] | None,
|
||||
object_is_public: bool | None,
|
||||
) -> None:
|
||||
"""
|
||||
All users can create/edit permission synced objects if they don't specify a group
|
||||
All admin actions are allowed.
|
||||
Prevents non-admins from creating/editing:
|
||||
- public objects
|
||||
- objects with no groups
|
||||
- objects that belong to a group they don't curate
|
||||
"""
|
||||
if object_is_perm_sync and not target_group_ids:
|
||||
return
|
||||
|
||||
if not user or user.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
|
||||
@@ -4,14 +4,17 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.confluence.connector import ConfluenceConnector
|
||||
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from danswer.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -160,13 +163,7 @@ def _extract_read_access_restrictions(
|
||||
f"Email for user {user['username']} not found in Confluence"
|
||||
)
|
||||
else:
|
||||
if user.get("email") is not None:
|
||||
logger.warning(f"Cant find email for user {user.get('displayName')}")
|
||||
logger.warning(
|
||||
"This user needs to make their email accessible in Confluence Settings"
|
||||
)
|
||||
|
||||
logger.warning(f"no user email or username for {user}")
|
||||
logger.warning(f"User {user} does not have an email or username")
|
||||
|
||||
# Extract the groups with read access
|
||||
read_access_group = read_access_restrictions.get("group", {})
|
||||
@@ -193,12 +190,12 @@ def _fetch_all_page_restrictions_for_space(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> dict[str, ExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
Otherwise, use the space's restrictions.
|
||||
"""
|
||||
document_restrictions: list[DocExternalAccess] = []
|
||||
document_restrictions: dict[str, ExternalAccess] = {}
|
||||
|
||||
for slim_doc in slim_docs:
|
||||
if slim_doc.perm_sync_data is None:
|
||||
@@ -210,34 +207,21 @@ def _fetch_all_page_restrictions_for_space(
|
||||
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
if restrictions:
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=restrictions,
|
||||
)
|
||||
)
|
||||
# If there are restrictions, then we don't need to use the space's restrictions
|
||||
continue
|
||||
|
||||
space_key = slim_doc.perm_sync_data.get("space_key")
|
||||
if space_permissions := space_permissions_by_space_key.get(space_key):
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
document_restrictions[slim_doc.id] = restrictions
|
||||
else:
|
||||
space_key = slim_doc.perm_sync_data.get("space_key")
|
||||
if space_permissions := space_permissions_by_space_key.get(space_key):
|
||||
document_restrictions[slim_doc.id] = space_permissions
|
||||
else:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
|
||||
return document_restrictions
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> None:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -263,8 +247,20 @@ def confluence_doc_sync(
|
||||
for doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
permissions_by_doc_id = _fetch_all_page_restrictions_for_space(
|
||||
confluence_client=confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
)
|
||||
|
||||
all_emails = set()
|
||||
for doc_id, page_specific_access in permissions_by_doc_id.items():
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=page_specific_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
||||
all_emails.update(page_specific_access.external_user_emails)
|
||||
|
||||
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(all_emails))
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from danswer.connectors.confluence.utils import build_confluence_client
|
||||
from danswer.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -36,8 +40,9 @@ def _get_group_members_email_paginated(
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
) -> None:
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
confluence_client = build_confluence_client(
|
||||
credentials_json=cc_pair.credential.credential_json,
|
||||
@@ -58,13 +63,20 @@ def confluence_group_sync(
|
||||
group_member_emails = _get_group_members_email_paginated(
|
||||
confluence_client, group_name
|
||||
)
|
||||
if not group_member_emails:
|
||||
continue
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=list(group_member_emails)
|
||||
)
|
||||
if group_members:
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_ids=[user.id for user in group_members],
|
||||
)
|
||||
)
|
||||
|
||||
return danswer_groups
|
||||
replace_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=danswer_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.gmail.connector import GmailConnector
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -28,8 +31,9 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> None:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -41,7 +45,6 @@ def gmail_doc_sync(
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
|
||||
|
||||
document_external_access: list[DocExternalAccess] = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
if slim_doc.perm_sync_data is None:
|
||||
@@ -53,11 +56,13 @@ def gmail_doc_sync(
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
document_external_access.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
)
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session,
|
||||
emails=list(ext_access.external_user_emails),
|
||||
)
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
return document_external_access
|
||||
|
||||
@@ -2,7 +2,8 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -10,7 +11,9 @@ from danswer.connectors.google_utils.resources import get_drive_service
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -110,13 +113,8 @@ def _get_permissions_from_slim_doc(
|
||||
elif permission_type == "group":
|
||||
group_emails.add(permission["emailAddress"])
|
||||
elif permission_type == "domain" and company_domain:
|
||||
if permission.get("domain") == company_domain:
|
||||
if permission["domain"] == company_domain:
|
||||
public = True
|
||||
else:
|
||||
logger.warning(
|
||||
"Permission is type domain but does not match company domain:"
|
||||
f"\n {permission}"
|
||||
)
|
||||
elif permission_type == "anyone":
|
||||
public = True
|
||||
|
||||
@@ -128,8 +126,9 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> None:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -143,17 +142,19 @@ def gdrive_doc_sync(
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
document_external_accesses = []
|
||||
for slim_doc_batch in slim_doc_generator:
|
||||
for slim_doc in slim_doc_batch:
|
||||
ext_access = _get_permissions_from_slim_doc(
|
||||
google_drive_connector=google_drive_connector,
|
||||
slim_doc=slim_doc,
|
||||
)
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=slim_doc.id,
|
||||
)
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session,
|
||||
emails=list(ext_access.external_user_emails),
|
||||
)
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=slim_doc.id,
|
||||
external_access=ext_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
||||
return document_external_accesses
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
from danswer.connectors.google_utils.resources import get_admin_service
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
) -> None:
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
@@ -39,14 +44,20 @@ def gdrive_group_sync(
|
||||
):
|
||||
group_member_emails.append(member["email"])
|
||||
|
||||
if not group_member_emails:
|
||||
continue
|
||||
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_email,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
# Add group members to DB and get their IDs
|
||||
group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=group_member_emails
|
||||
)
|
||||
if group_members:
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_email, user_ids=[user.id for user in group_members]
|
||||
)
|
||||
)
|
||||
|
||||
return danswer_groups
|
||||
replace_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=danswer_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
115
backend/ee/danswer/external_permissions/permission_sync.py
Normal file
115
backend/ee/danswer/external_permissions/permission_sync.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import get_document_ids_for_connector_credential_pair
|
||||
from danswer.document_index.factory import get_current_primary_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def run_external_group_permission_sync(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
|
||||
if group_sync_func is None:
|
||||
# Not all sync connectors support group permissions so this is fine
|
||||
return
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
# in postgres without committing
|
||||
logger.debug(f"Syncing groups for {source_type}")
|
||||
if group_sync_func is not None:
|
||||
group_sync_func(
|
||||
db_session,
|
||||
cc_pair,
|
||||
)
|
||||
|
||||
# update postgres
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Group Permissions")
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def run_external_doc_permission_sync(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> None:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
last_time_perm_sync = cc_pair.last_time_perm_sync
|
||||
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No permission sync function found for source type: {source_type}"
|
||||
)
|
||||
|
||||
try:
|
||||
# This function updates:
|
||||
# - the user_email <-> document mapping
|
||||
# - the external_user_group_id <-> document mapping
|
||||
# in postgres without committing
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
doc_sync_func(
|
||||
db_session,
|
||||
cc_pair,
|
||||
)
|
||||
|
||||
# Get the document ids for the cc pair
|
||||
document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# This function fetches the updated access for the documents
|
||||
# and returns a dictionary of document_ids and access
|
||||
# This is the access we want to update vespa with
|
||||
docs_access = get_access_for_documents(
|
||||
document_ids=document_ids_for_cc_pair,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Then we build the update requests to update vespa
|
||||
update_reqs = [
|
||||
UpdateRequest(document_ids=[doc_id], access=doc_access)
|
||||
for doc_id, doc_access in docs_access.items()
|
||||
]
|
||||
|
||||
# Don't bother sync-ing secondary, it will be sync-ed after switch anyway
|
||||
document_index = get_current_primary_default_document_index(db_session)
|
||||
|
||||
# update vespa
|
||||
document_index.update(update_reqs)
|
||||
|
||||
cc_pair.last_time_perm_sync = datetime.now(timezone.utc)
|
||||
|
||||
# update postgres
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully synced docs for {source_type}")
|
||||
except Exception:
|
||||
logger.exception("Error Syncing Document Permissions")
|
||||
cc_pair.last_time_perm_sync = last_time_perm_sync
|
||||
db_session.rollback()
|
||||
@@ -1,12 +1,16 @@
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.access.models import ExternalAccess
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.connectors.slack.connector import get_channels
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from danswer.connectors.slack.connector import SlackPollConnector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.document import upsert_document_external_perms__no_commit
|
||||
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
|
||||
|
||||
@@ -14,15 +18,22 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _get_slack_document_ids_and_channels(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> dict[str, list[str]]:
|
||||
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
|
||||
slack_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
# Get all document ids that need their permissions updated
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.SLIM_RETRIEVAL,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
slim_doc_generator = slack_connector.retrieve_all_slim_documents()
|
||||
assert isinstance(runnable_connector, SlimConnector)
|
||||
|
||||
channel_doc_map: dict[str, list[str]] = {}
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if doc_metadata.perm_sync_data is None:
|
||||
continue
|
||||
@@ -35,11 +46,13 @@ def _get_slack_document_ids_and_channels(
|
||||
|
||||
|
||||
def _fetch_workspace_permissions(
|
||||
db_session: Session,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
) -> ExternalAccess:
|
||||
user_emails = set()
|
||||
for email in user_id_to_email_map.values():
|
||||
user_emails.add(email)
|
||||
batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
# No group<->document mapping for slack
|
||||
@@ -50,6 +63,7 @@ def _fetch_workspace_permissions(
|
||||
|
||||
|
||||
def _fetch_channel_permissions(
|
||||
db_session: Session,
|
||||
slack_client: WebClient,
|
||||
workspace_permissions: ExternalAccess,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
@@ -99,6 +113,9 @@ def _fetch_channel_permissions(
|
||||
# If no email is found, we skip the user
|
||||
continue
|
||||
user_id_to_email_map[member_id] = member_email
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session, [member_email]
|
||||
)
|
||||
|
||||
member_emails.add(member_email)
|
||||
|
||||
@@ -114,8 +131,9 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[DocExternalAccess]:
|
||||
) -> None:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
@@ -127,18 +145,19 @@ def slack_doc_sync(
|
||||
)
|
||||
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
|
||||
channel_doc_map = _get_slack_document_ids_and_channels(
|
||||
db_session=db_session,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
workspace_permissions = _fetch_workspace_permissions(
|
||||
db_session=db_session,
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
channel_permissions = _fetch_channel_permissions(
|
||||
db_session=db_session,
|
||||
slack_client=slack_client,
|
||||
workspace_permissions=workspace_permissions,
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
|
||||
document_external_accesses = []
|
||||
for channel_id, ext_access in channel_permissions.items():
|
||||
doc_ids = channel_doc_map.get(channel_id)
|
||||
if not doc_ids:
|
||||
@@ -146,10 +165,9 @@ def slack_doc_sync(
|
||||
continue
|
||||
|
||||
for doc_id in doc_ids:
|
||||
document_external_accesses.append(
|
||||
DocExternalAccess(
|
||||
external_access=ext_access,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
upsert_document_external_perms__no_commit(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=ext_access,
|
||||
source_type=cc_pair.connector.source,
|
||||
)
|
||||
return document_external_accesses
|
||||
|
||||
@@ -5,11 +5,14 @@ SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EM
|
||||
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
|
||||
"""
|
||||
from slack_sdk import WebClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -26,6 +29,7 @@ def _get_slack_group_ids(
|
||||
|
||||
|
||||
def _get_slack_group_members_email(
|
||||
db_session: Session,
|
||||
slack_client: WebClient,
|
||||
group_name: str,
|
||||
user_id_to_email_map: dict[str, str],
|
||||
@@ -45,14 +49,18 @@ def _get_slack_group_members_email(
|
||||
# If no email is found, we skip the user
|
||||
continue
|
||||
user_id_to_email_map[member_id] = member_email
|
||||
batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session, [member_email]
|
||||
)
|
||||
group_member_emails.append(member_email)
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def slack_group_sync(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
) -> None:
|
||||
slack_client = WebClient(
|
||||
token=cc_pair.credential.credential_json["slack_bot_token"]
|
||||
)
|
||||
@@ -61,13 +69,24 @@ def slack_group_sync(
|
||||
danswer_groups: list[ExternalUserGroup] = []
|
||||
for group_name in _get_slack_group_ids(slack_client):
|
||||
group_member_emails = _get_slack_group_members_email(
|
||||
db_session=db_session,
|
||||
slack_client=slack_client,
|
||||
group_name=group_name,
|
||||
user_id_to_email_map=user_id_to_email_map,
|
||||
)
|
||||
if not group_member_emails:
|
||||
continue
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(id=group_name, user_emails=group_member_emails)
|
||||
group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=group_member_emails
|
||||
)
|
||||
return danswer_groups
|
||||
if group_members:
|
||||
danswer_groups.append(
|
||||
ExternalUserGroup(
|
||||
id=group_name, user_ids=[user.id for user in group_members]
|
||||
)
|
||||
)
|
||||
|
||||
replace_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_defs=danswer_groups,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
|
||||
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
from ee.danswer.external_permissions.gmail.doc_sync import gmail_doc_sync
|
||||
@@ -12,18 +12,12 @@ from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group
|
||||
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
SyncFuncType = Callable[
|
||||
[
|
||||
Session,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[DocExternalAccess],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[ExternalUserGroup],
|
||||
None,
|
||||
]
|
||||
|
||||
# These functions update:
|
||||
@@ -31,7 +25,7 @@ GroupSyncFuncType = Callable[
|
||||
# - the external_user_group_id <-> document mapping
|
||||
# in postgres without committing
|
||||
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
|
||||
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
|
||||
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
||||
DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
|
||||
DocumentSource.CONFLUENCE: confluence_doc_sync,
|
||||
DocumentSource.SLACK: slack_doc_sync,
|
||||
@@ -42,21 +36,19 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
|
||||
# - the user_email <-> external_user_group_id mapping
|
||||
# in postgres without committing
|
||||
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
|
||||
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
|
||||
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
|
||||
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
|
||||
DocumentSource.CONFLUENCE: confluence_group_sync,
|
||||
}
|
||||
|
||||
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
}
|
||||
|
||||
EXTERNAL_GROUP_SYNC_PERIOD: int = 30 # 30 seconds
|
||||
|
||||
|
||||
def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
|
||||
return source_type in DOC_PERMISSIONS_FUNC_MAP
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
|
||||
from danswer.auth.users import auth_backend
|
||||
@@ -60,31 +59,6 @@ def get_application() -> FastAPI:
|
||||
if MULTI_TENANT:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
create_danswer_oauth_router(
|
||||
oauth_client,
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
# Points the user back to the login page
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
@@ -99,7 +73,6 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
from cohere import Client
|
||||
|
||||
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
|
||||
Embedding = List[float]
|
||||
|
||||
|
||||
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
base_path = os.path.join(os.getcwd(), "danswer", "seeding")
|
||||
|
||||
if cohere_enabled and COHERE_DEFAULT_API_KEY:
|
||||
initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json")
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
|
||||
embed_model = "embed-english-v3.0"
|
||||
|
||||
for doc in processed_docs:
|
||||
title_embed_response = cohere_client.embed(
|
||||
texts=[doc["title"]],
|
||||
model=embed_model,
|
||||
input_type="search_document",
|
||||
)
|
||||
content_embed_response = cohere_client.embed(
|
||||
texts=[doc["content"]],
|
||||
model=embed_model,
|
||||
input_type="search_document",
|
||||
)
|
||||
|
||||
doc["title_embedding"] = cast(
|
||||
List[Embedding], title_embed_response.embeddings
|
||||
)[0]
|
||||
doc["content_embedding"] = cast(
|
||||
List[Embedding], content_embed_response.embeddings
|
||||
)[0]
|
||||
else:
|
||||
initial_docs_path = os.path.join(base_path, "initial_docs.json")
|
||||
processed_docs = json.load(open(initial_docs_path))
|
||||
|
||||
return processed_docs
|
||||
@@ -38,4 +38,3 @@ class ImpersonateRequest(BaseModel):
|
||||
class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
referral_source: str | None = None
|
||||
|
||||
@@ -44,9 +44,7 @@ from shared_configs.enums import EmbeddingProvider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_create_tenant_id(
|
||||
email: str, referral_source: str | None = None
|
||||
) -> str:
|
||||
async def get_or_create_tenant_id(email: str) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -56,7 +54,7 @@ async def get_or_create_tenant_id(
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
tenant_id = await create_tenant(email)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
@@ -69,13 +67,13 @@ async def get_or_create_tenant_id(
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
async def create_tenant(email: str) -> str:
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
await notify_control_plane(tenant_id, email)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
@@ -132,18 +130,14 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
async def notify_control_plane(tenant_id: str, email: str) -> None:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantCreationPayload(
|
||||
tenant_id=tenant_id, email=email, referral_source=referral_source
|
||||
)
|
||||
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
|
||||
@@ -26,5 +26,4 @@ lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
boto3-stubs[s3]==1.34.133
|
||||
pandas==2.2.3
|
||||
pandas-stubs==2.2.3.241009
|
||||
cohere==5.6.1
|
||||
pandas-stubs==2.2.3.241009
|
||||
@@ -1,2 +1 @@
|
||||
python3-saml==1.15.0
|
||||
cohere==5.6.1
|
||||
python3-saml==1.15.0
|
||||
@@ -42,7 +42,7 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
@@ -56,7 +56,7 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
"connector_pruning",
|
||||
]
|
||||
|
||||
cmd_worker_indexing = [
|
||||
|
||||
@@ -142,14 +142,14 @@ async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
|
||||
# Prefix used for all tenant ids
|
||||
TENANT_ID_PREFIX = "tenant_"
|
||||
|
||||
DISALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("DISALLOWED_SLACK_BOT_TENANT_IDS")
|
||||
ALLOWED_SLACK_BOT_TENANT_IDS = os.environ.get("ALLOWED_SLACK_BOT_TENANT_IDS")
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST = (
|
||||
[tenant.strip() for tenant in DISALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
|
||||
if DISALLOWED_SLACK_BOT_TENANT_IDS
|
||||
[tenant.strip() for tenant in ALLOWED_SLACK_BOT_TENANT_IDS.split(",")]
|
||||
if ALLOWED_SLACK_BOT_TENANT_IDS
|
||||
else None
|
||||
)
|
||||
|
||||
IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_IDS")
|
||||
IGNORED_SYNCING_TENANT_IDS = os.environ.get("IGNORED_SYNCING_TENANT_ID")
|
||||
IGNORED_SYNCING_TENANT_LIST = (
|
||||
[tenant.strip() for tenant in IGNORED_SYNCING_TENANT_IDS.split(",")]
|
||||
if IGNORED_SYNCING_TENANT_IDS
|
||||
|
||||
@@ -33,7 +33,7 @@ stopasgroup=true
|
||||
command=celery -A danswer.background.celery.versioned_apps.light worker
|
||||
--loglevel=INFO
|
||||
--hostname=light@%%n
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert
|
||||
-Q vespa_metadata_sync,connector_deletion
|
||||
stdout_logfile=/var/log/celery_worker_light.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
@@ -45,7 +45,7 @@ stopasgroup=true
|
||||
command=celery -A danswer.background.celery.versioned_apps.heavy worker
|
||||
--loglevel=INFO
|
||||
--hostname=heavy@%%n
|
||||
-Q connector_pruning,connector_doc_permissions_sync,connector_external_group_sync
|
||||
-Q connector_pruning
|
||||
stdout_logfile=/var/log/celery_worker_heavy.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
|
||||
@@ -39,39 +39,24 @@ def test_confluence_connector_basic(
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 3
|
||||
assert len(doc_batch) == 2
|
||||
|
||||
for doc in doc_batch:
|
||||
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
|
||||
page_doc = doc
|
||||
elif ".txt" in doc.semantic_identifier:
|
||||
txt_doc = doc
|
||||
elif doc.semantic_identifier == "Page Within A Page":
|
||||
page_within_a_page_doc = doc
|
||||
|
||||
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
|
||||
assert page_within_a_page_doc.primary_owners
|
||||
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
assert len(page_within_a_page_doc.sections) == 1
|
||||
|
||||
page_within_a_page_section = page_within_a_page_doc.sections[0]
|
||||
page_within_a_page_text = "@Chris Weaver loves cherry pie"
|
||||
assert page_within_a_page_section.text == page_within_a_page_text
|
||||
assert (
|
||||
page_within_a_page_section.link
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
|
||||
)
|
||||
|
||||
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
|
||||
assert page_doc.metadata["labels"] == ["testlabel"]
|
||||
assert page_doc.primary_owners
|
||||
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
assert page_doc.primary_owners[0].email == "chris@danswer.ai"
|
||||
assert len(page_doc.sections) == 1
|
||||
|
||||
page_section = page_doc.sections[0]
|
||||
assert page_section.text == "test123 " + page_within_a_page_text
|
||||
section = page_doc.sections[0]
|
||||
assert section.text == "test123"
|
||||
assert (
|
||||
page_section.link
|
||||
section.link
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
|
||||
)
|
||||
|
||||
|
||||
@@ -8,11 +8,11 @@ import requests
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.server.documents.models import CeleryTaskStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from danswer.server.documents.models import DocumentSyncStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
@@ -73,7 +73,7 @@ class CCPairManager:
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config,
|
||||
access_type=access_type,
|
||||
is_public=(access_type == AccessType.PUBLIC),
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
@@ -147,22 +147,7 @@ class CCPairManager:
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_single(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> CCPairFullInfo | None:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
cc_pair_json = response.json()
|
||||
return CCPairFullInfo(**cc_pair_json)
|
||||
|
||||
@staticmethod
|
||||
def get_indexing_status_by_id(
|
||||
def get_one(
|
||||
cc_pair_id: int,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> ConnectorIndexingStatus | None:
|
||||
@@ -181,7 +166,7 @@ class CCPairManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_indexing_statuses(
|
||||
def get_all(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
response = requests.get(
|
||||
@@ -199,7 +184,7 @@ class CCPairManager:
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
all_cc_pairs = CCPairManager.get_indexing_statuses(user_performing_action)
|
||||
all_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for retrieved_cc_pair in all_cc_pairs:
|
||||
if retrieved_cc_pair.cc_pair_id == cc_pair.id:
|
||||
if verify_deleted:
|
||||
@@ -249,9 +234,7 @@ class CCPairManager:
|
||||
"""after: Wait for an indexing success time after this time"""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
fetched_cc_pairs = CCPairManager.get_indexing_statuses(
|
||||
user_performing_action
|
||||
)
|
||||
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for fetched_cc_pair in fetched_cc_pairs:
|
||||
if fetched_cc_pair.cc_pair_id != cc_pair.id:
|
||||
continue
|
||||
@@ -344,133 +327,57 @@ class CCPairManager:
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""This function triggers a permission sync.
|
||||
Naming / intent of this function probably could use improvement, but currently it's letting
|
||||
409 Conflict pass through since if it's running that's what we were trying to do anyway.
|
||||
"""
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
#
|
||||
if result.status_code != 409:
|
||||
result.raise_for_status()
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> datetime | None:
|
||||
) -> CeleryTaskStatus:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_str = response.json()
|
||||
|
||||
# If the response itself is a datetime string, parse it
|
||||
if not isinstance(response_str, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
return datetime.fromisoformat(response_str)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_doc_sync_statuses(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> list[DocumentSyncStatus]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
doc_sync_statuses: list[DocumentSyncStatus] = []
|
||||
for doc_sync_status in response.json():
|
||||
doc_sync_statuses.append(
|
||||
DocumentSyncStatus(
|
||||
doc_id=doc_sync_status["doc_id"],
|
||||
last_synced=datetime.fromisoformat(doc_sync_status["last_synced"]),
|
||||
last_modified=datetime.fromisoformat(
|
||||
doc_sync_status["last_modified"]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return doc_sync_statuses
|
||||
return CeleryTaskStatus(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
cc_pair: DATestCCPair,
|
||||
after: datetime,
|
||||
timeout: float = MAX_DELAY,
|
||||
number_of_updated_docs: int = 0,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> None:
|
||||
"""after: The task register time must be after this time."""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
last_synced = CCPairManager.get_sync_task(cc_pair, user_performing_action)
|
||||
if last_synced and last_synced > after:
|
||||
print(f"last_synced: {last_synced}")
|
||||
print(f"sync command start time: {after}")
|
||||
print(f"permission sync complete: cc_pair={cc_pair.id}")
|
||||
break
|
||||
task = CCPairManager.get_sync_task(cc_pair, user_performing_action)
|
||||
if not task:
|
||||
raise ValueError("Sync task not found.")
|
||||
|
||||
if not task.register_time or task.register_time < after:
|
||||
raise ValueError("Sync task register time is too early.")
|
||||
|
||||
if task.status == TaskStatus.SUCCESS:
|
||||
# Sync succeeded
|
||||
return
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"Permission sync was not completed within {timeout} seconds"
|
||||
f"CC pair syncing was not completed within {timeout} seconds"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Waiting for CC sync to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
# TODO: remove this sleep,
|
||||
# this shouldnt be necessary but something is off with the timing for the sync jobs
|
||||
time.sleep(5)
|
||||
|
||||
print("waiting for vespa sync")
|
||||
# wait for the vespa sync to complete once the permission sync is complete
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
doc_sync_statuses = CCPairManager.get_doc_sync_statuses(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
synced_docs = 0
|
||||
for doc_sync_status in doc_sync_statuses:
|
||||
if (
|
||||
doc_sync_status.last_synced is not None
|
||||
and doc_sync_status.last_modified is not None
|
||||
and doc_sync_status.last_synced >= doc_sync_status.last_modified
|
||||
and doc_sync_status.last_synced >= after
|
||||
and doc_sync_status.last_modified >= after
|
||||
):
|
||||
synced_docs += 1
|
||||
|
||||
if synced_docs >= number_of_updated_docs:
|
||||
print(f"all docs synced: cc_pair={cc_pair.id}")
|
||||
break
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed > timeout:
|
||||
raise TimeoutError(
|
||||
f"Vespa sync was not completed within {timeout} seconds"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Waiting for vespa sync to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||
f"Waiting for CC syncing to complete. elapsed={elapsed:.2f} timeout={timeout}"
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
@@ -485,7 +392,7 @@ class CCPairManager:
|
||||
cc_pair_id is good to do."""
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
cc_pairs = CCPairManager.get_indexing_statuses(user_performing_action)
|
||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
if cc_pair_id:
|
||||
found = False
|
||||
for cc_pair in cc_pairs:
|
||||
|
||||
@@ -4,7 +4,6 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.server.documents.models import ConnectorUpdateRequest
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -20,7 +19,7 @@ class ConnectorManager:
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
access_type: AccessType = AccessType.PUBLIC,
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DATestConnector:
|
||||
@@ -31,7 +30,7 @@ class ConnectorManager:
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config or {},
|
||||
access_type=access_type,
|
||||
is_public=is_public,
|
||||
groups=groups or [],
|
||||
)
|
||||
|
||||
@@ -52,7 +51,7 @@ class ConnectorManager:
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config or {},
|
||||
groups=groups,
|
||||
access_type=access_type,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -28,12 +28,10 @@ class TenantManager:
|
||||
def create(
|
||||
tenant_id: str | None = None,
|
||||
initial_admin_email: str | None = None,
|
||||
referral_source: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
body = {
|
||||
"tenant_id": tenant_id,
|
||||
"initial_admin_email": initial_admin_email,
|
||||
"referral_source": referral_source,
|
||||
}
|
||||
|
||||
token = generate_auth_token()
|
||||
|
||||
@@ -55,7 +55,7 @@ class DATestConnector(BaseModel):
|
||||
input_type: InputType
|
||||
connector_specific_config: dict[str, Any]
|
||||
groups: list[int] | None = None
|
||||
access_type: AccessType | None = None
|
||||
is_public: bool | None = None
|
||||
|
||||
|
||||
class SimpleTestDocument(BaseModel):
|
||||
|
||||
@@ -6,10 +6,6 @@ import pytest
|
||||
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
# from tests.load_env_vars import load_env_vars
|
||||
|
||||
# load_env_vars()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
|
||||
|
||||
@@ -67,7 +67,7 @@ def test_slack_permission_sync(
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
@@ -96,13 +96,11 @@ def test_slack_permission_sync(
|
||||
private_message = "Sara's favorite number is 346794"
|
||||
|
||||
# Add messages to channels
|
||||
print(f"\n Adding public message to channel: {public_message}")
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=public_channel,
|
||||
message=public_message,
|
||||
)
|
||||
print(f"\n Adding private message to channel: {private_message}")
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
@@ -119,6 +117,7 @@ def test_slack_permission_sync(
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
@@ -126,33 +125,26 @@ def test_slack_permission_sync(
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=2,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Search as admin with access to both channels
|
||||
print("\nSearching as admin user")
|
||||
danswer_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by admin user: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure admin user can see messages from both channels
|
||||
assert public_message in danswer_doc_message_strings
|
||||
assert private_message in danswer_doc_message_strings
|
||||
|
||||
# Search as test_user_2 with access to only the public channel
|
||||
print("\n Searching as test_user_2")
|
||||
danswer_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_2,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by test_user_2: ",
|
||||
"\ntop_documents content before removing from private channel for test_user_2: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
@@ -161,13 +153,12 @@ def test_slack_permission_sync(
|
||||
assert private_message not in danswer_doc_message_strings
|
||||
|
||||
# Search as test_user_1 with access to both channels
|
||||
print("\n Searching as test_user_1")
|
||||
danswer_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_1,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by test_user_1 before being removed from private channel: ",
|
||||
"\ntop_documents content before removing from private channel for test_user_1: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
@@ -176,8 +167,7 @@ def test_slack_permission_sync(
|
||||
assert private_message in danswer_doc_message_strings
|
||||
|
||||
# ----------------------MAKE THE CHANGES--------------------------
|
||||
print("\n Removing test_user_1 from the private channel")
|
||||
before = datetime.now(timezone.utc)
|
||||
print("\nRemoving test_user_1 from the private channel")
|
||||
# Remove test_user_1 from the private channel
|
||||
desired_channel_members = [admin_user]
|
||||
SlackManager.set_channel_members(
|
||||
@@ -195,20 +185,18 @@ def test_slack_permission_sync(
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||
# Ensure test_user_1 can no longer see messages from the private channel
|
||||
# Search as test_user_1 with access to only the public channel
|
||||
|
||||
danswer_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_1,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by test_user_1 after being removed from private channel: ",
|
||||
"\ntop_documents content after removing from private channel for test_user_1: ",
|
||||
danswer_doc_message_strings,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,8 +3,6 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
@@ -24,7 +22,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="flaky - see DAN-986 for details", strict=False)
|
||||
# @pytest.mark.xfail(reason="flaky - see DAN-835 for example", strict=False)
|
||||
def test_slack_prune(
|
||||
reset: None,
|
||||
vespa_client: vespa_fixture,
|
||||
@@ -64,7 +62,7 @@ def test_slack_prune(
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -14,12 +14,12 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
# Create Tenant 1 and its Admin User
|
||||
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
|
||||
TenantManager.create("tenant_dev1", "test1@test.com")
|
||||
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
|
||||
assert UserManager.verify_role(test_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
|
||||
TenantManager.create("tenant_dev2", "test2@test.com")
|
||||
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
|
||||
assert UserManager.verify_role(test_user2, UserRole.ADMIN)
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# Test flow from creating tenant to registering as a user
|
||||
def test_tenant_creation(reset_multitenant: None) -> None:
|
||||
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
|
||||
TenantManager.create("tenant_dev", "test@test.com")
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
|
||||
assert UserManager.verify_role(test_user, UserRole.ADMIN)
|
||||
@@ -26,7 +26,7 @@ def test_tenant_creation(reset_multitenant: None) -> None:
|
||||
test_connector = ConnectorManager.create(
|
||||
name="admin_test_connector",
|
||||
source=DocumentSource.FILE,
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
user_performing_action=test_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
import requests
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_limited(reset: None) -> None:
|
||||
"""Verify that with a limited role key, limited endpoints are accessible and
|
||||
others are not."""
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.LIMITED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# test limited endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/persona/0",
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# test basic endpoints
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/input_prompt",
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# test admin endpoints
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
@@ -29,25 +29,6 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def test_connector_creation(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connectors
|
||||
cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
assert cc_pair_info
|
||||
assert cc_pair_info.creator
|
||||
assert str(cc_pair_info.creator) == admin_user.id
|
||||
assert cc_pair_info.creator_email == admin_user.email
|
||||
|
||||
|
||||
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -50,11 +50,12 @@ def test_cc_pair_permissions(reset: None) -> None:
|
||||
user_groups_to_check=[user_group_1], user_performing_action=admin_user
|
||||
)
|
||||
|
||||
# Create a credentials that the curator is and is not curator of
|
||||
connector_1 = ConnectorManager.create(
|
||||
name="admin_owned_connector",
|
||||
name="curator_owned_connector",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
# currently we dont enforce permissions at the connector level
|
||||
@@ -66,7 +67,6 @@ def test_cc_pair_permissions(reset: None) -> None:
|
||||
# is_public=False,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
# Create a credentials that the curator is and is not curator of
|
||||
credential_1 = CredentialManager.create(
|
||||
name="curator_owned_credential",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
|
||||
@@ -5,7 +5,6 @@ the permissions of the curator manipulating connectors.
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
@@ -58,7 +57,7 @@ def test_connector_permissions(reset: None) -> None:
|
||||
name="invalid_connector_1",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
is_public=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -69,7 +68,7 @@ def test_connector_permissions(reset: None) -> None:
|
||||
name="invalid_connector_2",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id, user_group_2.id],
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -81,7 +80,7 @@ def test_connector_permissions(reset: None) -> None:
|
||||
name="valid_connector",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
assert valid_connector.id is not None
|
||||
@@ -122,7 +121,7 @@ def test_connector_permissions(reset: None) -> None:
|
||||
name="invalid_connector_3",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_2.id],
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -132,6 +131,6 @@ def test_connector_permissions(reset: None) -> None:
|
||||
name="invalid_connector_4",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
is_public=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -51,7 +51,7 @@ def test_whole_curator_flow(reset: None) -> None:
|
||||
test_connector = ConnectorManager.create(
|
||||
name="curator_test_connector",
|
||||
source=DocumentSource.FILE,
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
groups=[user_group_1.id],
|
||||
user_performing_action=curator,
|
||||
)
|
||||
@@ -130,7 +130,7 @@ def test_global_curator_flow(reset: None) -> None:
|
||||
test_connector = ConnectorManager.create(
|
||||
name="curator_test_connector",
|
||||
source=DocumentSource.FILE,
|
||||
access_type=AccessType.PRIVATE,
|
||||
is_public=False,
|
||||
groups=[user_group_1.id],
|
||||
user_performing_action=global_curator,
|
||||
)
|
||||
|
||||
@@ -139,7 +139,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
cc_pair_1, now, timeout=60, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
selected_cc_pair = CCPairManager.get_indexing_status_by_id(
|
||||
selected_cc_pair = CCPairManager.get_one(
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
assert selected_cc_pair is not None, "cc_pair not found after indexing!"
|
||||
@@ -156,7 +156,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
cc_pair_1, now, timeout=60, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
selected_cc_pair = CCPairManager.get_indexing_status_by_id(
|
||||
selected_cc_pair = CCPairManager.get_one(
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
assert selected_cc_pair is not None, "cc_pair not found after pruning!"
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.12
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
@@ -25,7 +25,7 @@ spec:
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
"connector_pruning",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.12
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.10
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user