mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00

Compare commits

28 Commits
| Author | SHA1 | Date |
|---|---|---|
| | c29beaf403 | |
| | 46f84d15f8 | |
| | e8c93199f2 | |
| | 637b6d9e75 | |
| | 54dc1ac917 | |
| | 21d5cc43f8 | |
| | 7c841051ed | |
| | 6e91964924 | |
| | facf1d55a0 | |
| | d68f8d6fbc | |
| | 65a205d488 | |
| | 485f3f72fa | |
| | dcbea883ae | |
| | a50a3944b3 | |
| | 60471b6a73 | |
| | d703e694ce | |
| | 6066042fef | |
| | eb0e20b9e4 | |
| | 490a68773b | |
| | 227aff1e47 | |
| | 6e29d1944c | |
| | 22189f02c6 | |
| | fdc4811fce | |
| | 021d0cf314 | |
| | 942e47db29 | |
| | f4a020b599 | |
| | 5166649eae | |
| | ba805f766f | |
`.github/workflows/pr-Integration-tests.yml` (vendored, 3 changed lines)

```diff
@@ -197,7 +197,8 @@ jobs:
             -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
             -e TEST_WEB_HOSTNAME=test-runner \
             danswer/danswer-integration:test \
-            /app/tests/integration/tests
+            /app/tests/integration/tests \
+            /app/tests/integration/connector_job_tests
         continue-on-error: true
         id: run_tests
```

`.github/workflows/pr-helm-chart-testing.yml` (vendored, 31 changed lines)

```diff
@@ -23,21 +23,6 @@ jobs:
         with:
           version: v3.14.4

-      - name: Set up Python
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11'
-          cache: 'pip'
-          cache-dependency-path: |
-            backend/requirements/default.txt
-            backend/requirements/dev.txt
-            backend/requirements/model_server.txt
-      - run: |
-          python -m pip install --upgrade pip
-          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
-          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
-          pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
-
       - name: Set up chart-testing
         uses: helm/chart-testing-action@v2.6.1

@@ -52,6 +37,22 @@ jobs:
            echo "changed=true" >> "$GITHUB_OUTPUT"
          fi

+      # rkuo: I don't think we need python?
+      # - name: Set up Python
+      #   uses: actions/setup-python@v5
+      #   with:
+      #     python-version: '3.11'
+      #     cache: 'pip'
+      #     cache-dependency-path: |
+      #       backend/requirements/default.txt
+      #       backend/requirements/dev.txt
+      #       backend/requirements/model_server.txt
+      # - run: |
+      #     python -m pip install --upgrade pip
+      #     pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
+      #     pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
+      #     pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
+
      # lint all charts if any changes were detected
      - name: Run chart-testing (lint)
        if: steps.list-changed.outputs.changed == 'true'
```

`.vscode/launch.template.jsonc` (vendored, 4 changed lines)

```diff
@@ -203,7 +203,7 @@
         "--loglevel=INFO",
         "--hostname=light@%n",
         "-Q",
-        "vespa_metadata_sync,connector_deletion",
+        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
       ],
       "presentation": {
         "group": "2",
@@ -232,7 +232,7 @@
         "--loglevel=INFO",
         "--hostname=heavy@%n",
         "-Q",
-        "connector_pruning",
+        "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
       ],
       "presentation": {
         "group": "2",
```

New Alembic migration (@@ -0,0 +1,68 @@):

```python
"""default chosen assistants to none

Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
    )

    op.execute(
        """
        UPDATE "user"
        SET chosen_assistants_new =
            CASE
                WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
                ELSE chosen_assistants
            END
        """
    )

    op.drop_column("user", "chosen_assistants")

    op.alter_column(
        "user", "chosen_assistants_new", new_column_name="chosen_assistants"
    )


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "chosen_assistants_old",
            postgresql.JSONB(),
            nullable=False,
            server_default="[-2, -1, 0]",
        ),
    )

    op.execute(
        """
        UPDATE "user"
        SET chosen_assistants_old =
            CASE
                WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
                ELSE chosen_assistants
            END
        """
    )

    op.drop_column("user", "chosen_assistants")

    op.alter_column(
        "user", "chosen_assistants_old", new_column_name="chosen_assistants"
    )
```

`backend/alembic/versions/2daa494a0851_add_group_sync_time.py` (new file, 30 lines):

```python
"""add-group-sync-time

Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_time_external_group_sync",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "last_time_external_group_sync")
```

New Alembic migration (@@ -0,0 +1,30 @@):

```python
"""add creator to cc pair

Revision ID: 9cf5c00f72fe
Revises: c0fd6e4da83a
Create Date: 2024-11-12 15:16:42.682902

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "creator_id",
            sa.UUID(as_uuid=True),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "creator_id")
```

```diff
@@ -288,6 +288,15 @@ def upgrade() -> None:


 def downgrade() -> None:
+    # NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
+    # below
+    op.execute("DELETE FROM chat_feedback")
+    op.execute("DELETE FROM chat_message__search_doc")
+    op.execute("DELETE FROM document_retrieval_feedback")
+    op.execute("DELETE FROM document_retrieval_feedback")
+    op.execute("DELETE FROM chat_message")
+    op.execute("DELETE FROM chat_session")
+
     op.drop_constraint(
         "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
     )
```

```diff
@@ -23,6 +23,56 @@ def upgrade() -> None:


 def downgrade() -> None:
+    # Delete chat messages and feedback first since they reference chat sessions
+    # Get chat messages from sessions with null persona_id
+    chat_messages_query = """
+        SELECT id
+        FROM chat_message
+        WHERE chat_session_id IN (
+            SELECT id
+            FROM chat_session
+            WHERE persona_id IS NULL
+        )
+    """
+
+    # Delete dependent records first
+    op.execute(
+        f"""
+        DELETE FROM document_retrieval_feedback
+        WHERE chat_message_id IN (
+            {chat_messages_query}
+        )
+        """
+    )
+    op.execute(
+        f"""
+        DELETE FROM chat_message__search_doc
+        WHERE chat_message_id IN (
+            {chat_messages_query}
+        )
+        """
+    )
+
+    # Delete chat messages
+    op.execute(
+        """
+        DELETE FROM chat_message
+        WHERE chat_session_id IN (
+            SELECT id
+            FROM chat_session
+            WHERE persona_id IS NULL
+        )
+        """
+    )
+
+    # Now we can safely delete the chat sessions
+    op.execute(
+        """
+        DELETE FROM chat_session
+        WHERE persona_id IS NULL
+        """
+    )
+
     op.alter_column(
         "chat_session",
         "persona_id",
```

```diff
@@ -16,6 +16,41 @@ class ExternalAccess:
     is_public: bool


+@dataclass(frozen=True)
+class DocExternalAccess:
+    external_access: ExternalAccess
+    # The document ID
+    doc_id: str
+
+    def to_dict(self) -> dict:
+        return {
+            "external_access": {
+                "external_user_emails": list(self.external_access.external_user_emails),
+                "external_user_group_ids": list(
+                    self.external_access.external_user_group_ids
+                ),
+                "is_public": self.external_access.is_public,
+            },
+            "doc_id": self.doc_id,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> "DocExternalAccess":
+        external_access = ExternalAccess(
+            external_user_emails=set(
+                data["external_access"].get("external_user_emails", [])
+            ),
+            external_user_group_ids=set(
+                data["external_access"].get("external_user_group_ids", [])
+            ),
+            is_public=data["external_access"]["is_public"],
+        )
+        return cls(
+            external_access=external_access,
+            doc_id=data["doc_id"],
+        )
+
+
 @dataclass(frozen=True)
 class DocumentAccess(ExternalAccess):
     # User emails for Danswer users, None indicates admin
```
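
The `to_dict`/`from_dict` pair exists so a `DocExternalAccess` can be handed to Celery as a plain JSON-serializable kwarg and rebuilt on the other side (the `update_external_document_permissions_task` further down does exactly that). A minimal round-trip sketch, assuming the `danswer` package is importable; the document ID, email, and group values are illustrative:

```python
# Round-trip sketch for DocExternalAccess; doc/group/email values are made up.
from danswer.access.models import DocExternalAccess, ExternalAccess

access = ExternalAccess(
    external_user_emails={"user@example.com"},
    external_user_group_ids={"eng-team"},
    is_public=False,
)
doc_access = DocExternalAccess(external_access=access, doc_id="confluence__PAGE-123")

serialized = doc_access.to_dict()          # plain dict, safe to send as a task kwarg
restored = DocExternalAccess.from_dict(serialized)
assert restored == doc_access              # frozen dataclasses compare by field value
```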

```diff
@@ -15,6 +15,7 @@ class UserRole(str, Enum):
     for all groups they are a member of
     """

+    LIMITED = "limited"
     BASIC = "basic"
     ADMIN = "admin"
     CURATOR = "curator"
```

```diff
@@ -228,12 +228,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         safe: bool = False,
         request: Optional[Request] = None,
     ) -> User:
+        referral_source = None
+        if request is not None:
+            referral_source = request.cookies.get("referral_source", None)
+
         tenant_id = await fetch_ee_implementation_or_noop(
             "danswer.server.tenants.provisioning",
             "get_or_create_tenant_id",
             async_return_default_schema,
         )(
             email=user_create.email,
+            referral_source=referral_source,
         )

         async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -294,12 +299,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         associate_by_email: bool = False,
         is_verified_by_default: bool = False,
     ) -> models.UOAP:
+        referral_source = None
+        if request:
+            referral_source = getattr(request.state, "referral_source", None)
+
         tenant_id = await fetch_ee_implementation_or_noop(
             "danswer.server.tenants.provisioning",
             "get_or_create_tenant_id",
             async_return_default_schema,
         )(
             email=account_email,
+            referral_source=referral_source,
         )

         if not tenant_id:
@@ -652,12 +662,26 @@ async def current_user_with_expired_token(
     return await double_check_user(user, include_expired=True)


-async def current_user(
+async def current_limited_user(
     user: User | None = Depends(optional_user),
 ) -> User | None:
     return await double_check_user(user)


+async def current_user(
+    user: User | None = Depends(optional_user),
+) -> User | None:
+    user = await double_check_user(user)
+    if not user:
+        return None
+
+    if user.role == UserRole.LIMITED:
+        raise BasicAuthenticationError(
+            detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
+        )
+    return user
+
+
 async def current_curator_or_admin_user(
     user: User | None = Depends(current_user),
 ) -> User | None:
```
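
For context on how the new dependency pair is meant to be consumed: `current_user` now rejects LIMITED users while `current_limited_user` lets them through. A hedged sketch of a route using it; the route path is illustrative and the import path for the auth module is assumed, not taken from this diff:

```python
# Sketch only: a route that LIMITED users may call, guarded by current_limited_user.
from fastapi import APIRouter, Depends

from danswer.auth.users import current_limited_user  # module path assumed
from danswer.db.models import User

router = APIRouter()


@router.get("/me")  # illustrative path
async def who_am_i(user: User | None = Depends(current_limited_user)) -> dict:
    # A LIMITED user passes this dependency; Depends(current_user) would raise for them.
    return {"email": user.email if user else None}
```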

```diff
@@ -711,8 +735,6 @@ def generate_state_token(


 # refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91


 def create_danswer_oauth_router(
     oauth_client: BaseOAuth2,
     backend: AuthenticationBackend,
@@ -762,15 +784,22 @@ def get_oauth_router(
         response_model=OAuth2AuthorizeResponse,
     )
     async def authorize(
-        request: Request, scopes: List[str] = Query(None)
+        request: Request,
+        scopes: List[str] = Query(None),
     ) -> OAuth2AuthorizeResponse:
+        referral_source = request.cookies.get("referral_source", None)
+
         if redirect_url is not None:
             authorize_redirect_url = redirect_url
         else:
             authorize_redirect_url = str(request.url_for(callback_route_name))

         next_url = request.query_params.get("next", "/")
-        state_data: Dict[str, str] = {"next_url": next_url}

+        state_data: Dict[str, str] = {
+            "next_url": next_url,
+            "referral_source": referral_source or "default_referral",
+        }
         state = generate_state_token(state_data, state_secret)
         authorization_url = await oauth_client.get_authorization_url(
             authorize_redirect_url,
@@ -829,8 +858,11 @@ def get_oauth_router(
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

         next_url = state_data.get("next_url", "/")
+        referral_source = state_data.get("referral_source", None)

-        # Authenticate user
+        request.state.referral_source = referral_source
+
+        # Proceed to authenticate or create the user
         try:
             user = await user_manager.oauth_callback(
                 oauth_client.name,
@@ -872,7 +904,6 @@ def get_oauth_router(
         redirect_response.status_code = response.status_code
         if hasattr(response, "media_type"):
             redirect_response.media_type = response.media_type

         return redirect_response

     return router
```

```diff
@@ -24,6 +24,8 @@ from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
+from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_document_set import RedisDocumentSet
 from danswer.redis.redis_pool import get_redis_client
@@ -136,6 +138,22 @@ def on_task_postrun(
             RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
         return

+    if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
+        cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
+        if cc_pair_id is not None:
+            RedisConnectorPermissionSync.remove_from_taskset(
+                int(cc_pair_id), task_id, r
+            )
+        return
+
+    if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
+        cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
+        if cc_pair_id is not None:
+            RedisConnectorExternalGroupSync.remove_from_taskset(
+                int(cc_pair_id), task_id, r
+            )
+        return
+

 def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
     """The first signal sent on celery worker startup"""
```

```diff
@@ -12,6 +12,7 @@ from danswer.db.engine import get_all_tenant_ids
 from danswer.db.engine import SqlEngine
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
 from shared_configs.configs import MULTI_TENANT

 logger = setup_logger(__name__)
@@ -72,6 +73,15 @@ class DynamicTenantScheduler(PersistentScheduler):
             logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")

             for tenant_id in tenant_ids:
+                if (
+                    IGNORED_SYNCING_TENANT_LIST
+                    and tenant_id in IGNORED_SYNCING_TENANT_LIST
+                ):
+                    logger.info(
+                        f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
+                    )
+                    continue
+
                 if tenant_id not in existing_tenants:
                     logger.info(f"Processing new tenant: {tenant_id}")
```

```diff
@@ -91,5 +91,7 @@ def on_setup_logging(
 celery_app.autodiscover_tasks(
     [
         "danswer.background.celery.tasks.pruning",
+        "danswer.background.celery.tasks.doc_permission_syncing",
+        "danswer.background.celery.tasks.external_group_syncing",
     ]
 )
```

```diff
@@ -6,6 +6,7 @@ from celery import signals
 from celery import Task
 from celery.signals import celeryd_init
 from celery.signals import worker_init
+from celery.signals import worker_process_init
 from celery.signals import worker_ready
 from celery.signals import worker_shutdown

@@ -59,7 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
-    SqlEngine.init_engine(pool_size=8, max_overflow=0)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)

     # Startup checks are not needed in multi-tenant case
     if MULTI_TENANT:
@@ -81,6 +82,11 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
     app_base.on_worker_shutdown(sender, **kwargs)


+@worker_process_init.connect
+def init_worker(**kwargs: Any) -> None:
+    SqlEngine.reset_engine()
+
+
 @signals.setup_logging.connect
 def on_setup_logging(
     loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
```

```diff
@@ -92,5 +92,6 @@ celery_app.autodiscover_tasks(
         "danswer.background.celery.tasks.shared",
         "danswer.background.celery.tasks.vespa",
         "danswer.background.celery.tasks.connector_deletion",
+        "danswer.background.celery.tasks.doc_permission_syncing",
     ]
 )
```

```diff
@@ -20,6 +20,8 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
 from danswer.db.engine import SqlEngine
 from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
+from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
 from danswer.redis.redis_connector_index import RedisConnectorIndex
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -134,6 +136,10 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:

     RedisConnectorStop.reset_all(r)

+    RedisConnectorPermissionSync.reset_all(r)
+
+    RedisConnectorExternalGroupSync.reset_all(r)
+

 @worker_ready.connect
 def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@@ -233,6 +239,8 @@ celery_app.autodiscover_tasks(
         "danswer.background.celery.tasks.connector_deletion",
         "danswer.background.celery.tasks.indexing",
         "danswer.background.celery.tasks.periodic",
+        "danswer.background.celery.tasks.doc_permission_syncing",
+        "danswer.background.celery.tasks.external_group_syncing",
         "danswer.background.celery.tasks.pruning",
         "danswer.background.celery.tasks.shared",
         "danswer.background.celery.tasks.vespa",
```

Removed file (@@ -1,96 +0,0 @@) — the previous DynamicTenantScheduler implementation:

```python
from datetime import timedelta
from typing import Any

from celery.beat import PersistentScheduler  # type: ignore
from celery.utils.log import get_task_logger

from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation

logger = get_task_logger(__name__)


class DynamicTenantScheduler(PersistentScheduler):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._reload_interval = timedelta(minutes=1)
        self._last_reload = self.app.now() - self._reload_interval

    def setup_schedule(self) -> None:
        super().setup_schedule()

    def tick(self) -> float:
        retval = super().tick()
        now = self.app.now()
        if (
            self._last_reload is None
            or (now - self._last_reload) > self._reload_interval
        ):
            logger.info("Reloading schedule to check for new tenants...")
            self._update_tenant_tasks()
            self._last_reload = now
        return retval

    def _update_tenant_tasks(self) -> None:
        logger.info("Checking for tenant task updates...")
        try:
            tenant_ids = get_all_tenant_ids()
            tasks_to_schedule = fetch_versioned_implementation(
                "danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
            )

            new_beat_schedule: dict[str, dict[str, Any]] = {}

            current_schedule = getattr(self, "_store", {"entries": {}}).get(
                "entries", {}
            )

            existing_tenants = set()
            for task_name in current_schedule.keys():
                if "-" in task_name:
                    existing_tenants.add(task_name.split("-")[-1])

            for tenant_id in tenant_ids:
                if tenant_id not in existing_tenants:
                    logger.info(f"Found new tenant: {tenant_id}")

                for task in tasks_to_schedule():
                    task_name = f"{task['name']}-{tenant_id}"
                    new_task = {
                        "task": task["task"],
                        "schedule": task["schedule"],
                        "kwargs": {"tenant_id": tenant_id},
                    }
                    if options := task.get("options"):
                        new_task["options"] = options
                    new_beat_schedule[task_name] = new_task

            if self._should_update_schedule(current_schedule, new_beat_schedule):
                logger.info(
                    "Updating schedule",
                    extra={
                        "new_tasks": len(new_beat_schedule),
                        "current_tasks": len(current_schedule),
                    },
                )
                if not hasattr(self, "_store"):
                    self._store: dict[str, dict] = {"entries": {}}
                self.update_from_dict(new_beat_schedule)
                logger.info(f"New schedule: {new_beat_schedule}")

                logger.info("Tenant tasks updated successfully")
            else:
                logger.debug("No schedule updates needed")

        except (AttributeError, KeyError):
            logger.exception("Failed to process task configuration")
        except Exception:
            logger.exception("Unexpected error updating tenant tasks")

    def _should_update_schedule(
        self, current_schedule: dict, new_schedule: dict
    ) -> bool:
        """Compare schedules to determine if an update is needed."""
        current_tasks = set(current_schedule.keys())
        new_tasks = set(new_schedule.keys())
        return current_tasks != new_tasks
```

```diff
@@ -81,7 +81,7 @@ def extract_ids_from_runnable_connector(
     callback: RunIndexingCallbackInterface | None = None,
 ) -> set[str]:
     """
-    If the PruneConnector hasnt been implemented for the given connector, just pull
+    If the SlimConnector hasnt been implemented for the given connector, just pull
     all docs using the load_from_state and grab out the IDs.

     Optionally, a callback can be passed to handle the length of each document batch.
```

```diff
@@ -8,7 +8,7 @@ tasks_to_schedule = [
     {
         "name": "check-for-vespa-sync",
         "task": "check_for_vespa_sync_task",
-        "schedule": timedelta(seconds=5),
+        "schedule": timedelta(seconds=20),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
     {
@@ -20,13 +20,13 @@ tasks_to_schedule = [
     {
         "name": "check-for-indexing",
         "task": "check_for_indexing",
-        "schedule": timedelta(seconds=10),
+        "schedule": timedelta(seconds=15),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
     {
         "name": "check-for-prune",
         "task": "check_for_pruning",
-        "schedule": timedelta(seconds=10),
+        "schedule": timedelta(seconds=15),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
     {
@@ -41,6 +41,18 @@ tasks_to_schedule = [
         "schedule": timedelta(seconds=5),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
+    {
+        "name": "check-for-doc-permissions-sync",
+        "task": "check_for_doc_permissions_sync",
+        "schedule": timedelta(seconds=30),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
+    {
+        "name": "check-for-external-group-sync",
+        "task": "check_for_external_group_sync",
+        "schedule": timedelta(seconds=20),
+        "options": {"priority": DanswerCeleryPriority.HIGH},
+    },
 ]
```
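
To show how these entries are consumed: the tenant-aware scheduler (seen elsewhere in this diff) expands each `tasks_to_schedule` item into one beat entry per tenant, keyed as `{name}-{tenant_id}` and passing the tenant id as a task kwarg. A small sketch of that expansion for one of the new entries; the tenant id is illustrative, and the `options` field is carried over unchanged when present:

```python
# Sketch of the per-tenant expansion applied to a single schedule entry above.
from datetime import timedelta

task = {
    "name": "check-for-doc-permissions-sync",
    "task": "check_for_doc_permissions_sync",
    "schedule": timedelta(seconds=30),
}
tenant_id = "tenant_abc123"  # illustrative

beat_entry_name = f"{task['name']}-{tenant_id}"
beat_entry = {
    "task": task["task"],
    "schedule": task["schedule"],
    "kwargs": {"tenant_id": tenant_id},
}
# The scheduler then installs {beat_entry_name: beat_entry} into the beat schedule.
```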

```diff
@@ -143,6 +143,12 @@ def try_generate_document_cc_pair_cleanup_tasks(
             f"cc_pair={cc_pair_id}"
         )

+    if redis_connector.permissions.fenced:
+        raise TaskDependencyError(
+            f"Connector deletion - Delayed (permissions in progress): "
+            f"cc_pair={cc_pair_id}"
+        )
+
     # add tasks to celery and build up the task set to monitor in redis
     redis_connector.delete.taskset_clear()
```

New file (@@ -0,0 +1,321 @@) — document permission syncing Celery tasks:

```python
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis

from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
    RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.document import upsert_document_external_perms
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP

logger = setup_logger()


DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3


# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15


def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Returns boolean indicating if external doc permissions sync is due."""

    if cc_pair.access_type != AccessType.SYNC:
        return False

    # skip doc permissions sync if not active
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
        return False

    # If the last sync is None, it has never been run so we run the sync
    last_perm_sync = cc_pair.last_time_perm_sync
    if last_perm_sync is None:
        return True

    source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)

    # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
    if not source_sync_period:
        return True

    # If the last sync is greater than the full fetch period, we run the sync
    next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
    if datetime.now(timezone.utc) >= next_sync:
        return True

    return False


@shared_task(
    name="check_for_doc_permissions_sync",
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return

        # get all cc pairs that need to be synced
        cc_pair_ids_to_sync: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
            cc_pairs = get_all_auto_sync_cc_pairs(db_session)

            for cc_pair in cc_pairs:
                if _is_external_doc_permissions_sync_due(cc_pair):
                    cc_pair_ids_to_sync.append(cc_pair.id)

        for cc_pair_id in cc_pair_ids_to_sync:
            tasks_created = try_creating_permissions_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
            if not tasks_created:
                continue

            task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
    finally:
        if lock_beat.owned():
            lock_beat.release()


def try_creating_permissions_sync_task(
    app: Celery,
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
    Returns None if no syncing is required."""
    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    LOCK_TIMEOUT = 30

    lock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
        timeout=LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        if redis_connector.permissions.fenced:
            return None

        if redis_connector.delete.fenced:
            return None

        if redis_connector.prune.fenced:
            return None

        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()

        custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"

        app.send_task(
            "connector_permission_sync_generator_task",
            kwargs=dict(
                cc_pair_id=cc_pair_id,
                tenant_id=tenant_id,
            ),
            queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
            task_id=custom_task_id,
            priority=DanswerCeleryPriority.HIGH,
        )

        # set a basic fence to start
        payload = RedisConnectorPermissionSyncData(
            started=None,
        )

        redis_connector.permissions.set_fence(payload)
    except Exception:
        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
        return None
    finally:
        if lock.owned():
            lock.release()

    return 1


@shared_task(
    name="connector_permission_sync_generator_task",
    acks_late=False,
    soft_time_limit=JOB_TIMEOUT,
    track_started=True,
    trail=False,
    bind=True,
)
def connector_permission_sync_generator_task(
    self: Task,
    cc_pair_id: int,
    tenant_id: str | None,
) -> None:
    """
    Permission sync task that handles document permission syncing for a given connector credential pair
    This task assumes that the task has already been properly fenced
    """

    doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
    doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
    doc_permission_sync_ctx_dict["request_id"] = self.request.id
    doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    r = get_redis_client(tenant_id=tenant_id)

    lock = r.lock(
        DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking=False)
    if not acquired:
        task_logger.warning(
            f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
        )
        return None

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
            if cc_pair is None:
                raise ValueError(
                    f"No connector credential pair found for id: {cc_pair_id}"
                )

            source_type = cc_pair.connector.source

            doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
            if doc_sync_func is None:
                raise ValueError(f"No doc sync func found for {source_type}")

            logger.info(f"Syncing docs for {source_type}")

            payload = RedisConnectorPermissionSyncData(
                started=datetime.now(timezone.utc),
            )
            redis_connector.permissions.set_fence(payload)

            document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

            task_logger.info(
                f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
            )
            tasks_generated = redis_connector.permissions.generate_tasks(
                self.app, lock, document_external_accesses, source_type
            )
            if tasks_generated is None:
                return None

            task_logger.info(
                f"RedisConnector.permissions.generate_tasks finished. "
                f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
            )

            redis_connector.permissions.generator_complete = tasks_generated

    except Exception as e:
        task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")

        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()
        redis_connector.permissions.set_fence(None)
        raise e
    finally:
        if lock.owned():
            lock.release()


@shared_task(
    name="update_external_document_permissions_task",
    soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
    time_limit=LIGHT_TIME_LIMIT,
    max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
    bind=True,
)
def update_external_document_permissions_task(
    self: Task,
    tenant_id: str | None,
    serialized_doc_external_access: dict,
    source_string: str,
) -> bool:
    document_external_access = DocExternalAccess.from_dict(
        serialized_doc_external_access
    )
    doc_id = document_external_access.doc_id
    external_access = document_external_access.external_access
    try:
        with get_session_with_tenant(tenant_id) as db_session:
            # Then we build the update requests to update vespa
            batch_add_non_web_user_if_not_exists(
                db_session=db_session,
                emails=list(external_access.external_user_emails),
            )
            upsert_document_external_perms(
                db_session=db_session,
                doc_id=doc_id,
                external_access=external_access,
                source_type=DocumentSource(source_string),
            )

            logger.debug(
                f"Successfully synced postgres document permissions for {doc_id}"
            )
        return True
    except Exception:
        logger.exception("Error Syncing Document Permissions")
        return False
```
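
The `check_for_*` beat tasks above (and the one in the next file) rely on the same redis-py locking idiom to keep periodic runs from overlapping: acquire without blocking, bail out if another run holds the lock, and release only if this process still owns it. A standalone sketch of that pattern; the lock name, timeout, and connection details are illustrative:

```python
# Minimal sketch of the non-overlapping beat-task pattern used above.
from redis import Redis

r = Redis()  # connection details are illustrative
lock = r.lock("da_lock:check_for_doc_permissions_sync", timeout=120)

if lock.acquire(blocking=False):
    try:
        pass  # run the periodic check here
    finally:
        # Only release a lock we still hold; it may have expired and been re-acquired.
        if lock.owned():
            lock.release()
else:
    pass  # another beat invocation is already running; skip this tick
```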

New file (@@ -0,0 +1,265 @@) — external group syncing Celery tasks:

```python
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis

from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP

logger = setup_logger()


EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3


# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15


def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Returns boolean indicating if external group sync is due."""

    if cc_pair.access_type != AccessType.SYNC:
        return False

    # skip pruning if not active
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
        return False

    # If there is not group sync function for the connector, we don't run the sync
    # This is fine because all sources dont necessarily have a concept of groups
    if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
        return False

    # If the last sync is None, it has never been run so we run the sync
    last_ext_group_sync = cc_pair.last_time_external_group_sync
    if last_ext_group_sync is None:
        return True

    source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD

    # If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
    if not source_sync_period:
        return True

    # If the last sync is greater than the full fetch period, we run the sync
    next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
    if datetime.now(timezone.utc) >= next_sync:
        return True

    return False


@shared_task(
    name="check_for_external_group_sync",
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat = r.lock(
        DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )

    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return

        cc_pair_ids_to_sync: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
            cc_pairs = get_all_auto_sync_cc_pairs(db_session)

            for cc_pair in cc_pairs:
                if _is_external_group_sync_due(cc_pair):
                    cc_pair_ids_to_sync.append(cc_pair.id)

        for cc_pair_id in cc_pair_ids_to_sync:
            tasks_created = try_creating_permissions_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
            if not tasks_created:
                continue

            task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
    finally:
        if lock_beat.owned():
            lock_beat.release()


def try_creating_permissions_sync_task(
    app: Celery,
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
    Returns None if no syncing is required."""
    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    LOCK_TIMEOUT = 30

    lock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
        timeout=LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        # Dont kick off a new sync if the previous one is still running
        if redis_connector.external_group_sync.fenced:
            return None

        redis_connector.external_group_sync.generator_clear()
        redis_connector.external_group_sync.taskset_clear()

        custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

        _ = app.send_task(
            "connector_external_group_sync_generator_task",
            kwargs=dict(
                cc_pair_id=cc_pair_id,
                tenant_id=tenant_id,
            ),
            queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
            task_id=custom_task_id,
            priority=DanswerCeleryPriority.HIGH,
        )
        # set a basic fence to start
        redis_connector.external_group_sync.set_fence(True)

    except Exception:
        task_logger.exception(
            f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
        )
        return None
    finally:
        if lock.owned():
            lock.release()

    return 1


@shared_task(
    name="connector_external_group_sync_generator_task",
    acks_late=False,
    soft_time_limit=JOB_TIMEOUT,
    track_started=True,
    trail=False,
    bind=True,
)
def connector_external_group_sync_generator_task(
    self: Task,
    cc_pair_id: int,
    tenant_id: str | None,
) -> None:
    """
    Permission sync task that handles document permission syncing for a given connector credential pair
    This task assumes that the task has already been properly fenced
    """

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    r = get_redis_client(tenant_id=tenant_id)

    lock = r.lock(
        DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
    )

    try:
        acquired = lock.acquire(blocking=False)
        if not acquired:
            task_logger.warning(
                f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
            )
            return None

        with get_session_with_tenant(tenant_id) as db_session:
            cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
            if cc_pair is None:
                raise ValueError(
                    f"No connector credential pair found for id: {cc_pair_id}"
                )

            source_type = cc_pair.connector.source

            ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
            if ext_group_sync_func is None:
                raise ValueError(f"No external group sync func found for {source_type}")

            logger.info(f"Syncing docs for {source_type}")

            external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)

            logger.info(
                f"Syncing {len(external_user_groups)} external user groups for {source_type}"
            )

            replace_user__ext_group_for_cc_pair(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
                group_defs=external_user_groups,
                source=cc_pair.connector.source,
            )
            logger.info(
                f"Synced {len(external_user_groups)} external user groups for {source_type}"
            )

            mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)

    except Exception as e:
        task_logger.exception(
            f"Failed to run external group sync: cc_pair={cc_pair_id}"
        )

        redis_connector.external_group_sync.generator_clear()
        redis_connector.external_group_sync.taskset_clear()
        raise e
    finally:
        # we always want to clear the fence after the task is done or failed so it doesn't get stuck
        redis_connector.external_group_sync.set_fence(False)
        if lock.owned():
            lock.release()
```

```diff
@@ -38,6 +38,35 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()


+def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
+    """Returns boolean indicating if pruning is due."""
+
+    # skip pruning if no prune frequency is set
+    # pruning can still be forced via the API which will run a pruning task directly
+    if not cc_pair.connector.prune_freq:
+        return False
+
+    # skip pruning if not active
+    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
+        return False
+
+    # skip pruning if the next scheduled prune time hasn't been reached yet
+    last_pruned = cc_pair.last_pruned
+    if not last_pruned:
+        if not cc_pair.last_successful_index_time:
+            # if we've never indexed, we can't prune
+            return False
+
+        # if never pruned, use the last time the connector indexed successfully
+        last_pruned = cc_pair.last_successful_index_time
+
+    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
+    if datetime.now(timezone.utc) < next_prune:
+        return False
+
+    return True
+
+
 @shared_task(
     name="check_for_pruning",
     soft_time_limit=JOB_TIMEOUT,
@@ -69,7 +98,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
                 if not cc_pair:
                     continue

-                if not is_pruning_due(cc_pair, db_session, r):
+                if not _is_pruning_due(cc_pair):
                     continue

                 tasks_created = try_creating_prune_generator_task(
@@ -90,47 +119,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
             lock_beat.release()


-def is_pruning_due(
-    cc_pair: ConnectorCredentialPair,
-    db_session: Session,
-    r: Redis,
-) -> bool:
-    """Returns an int if pruning is triggered.
-    The int represents the number of prune tasks generated (in this case, only one
-    because the task is a long running generator task.)
-    Returns None if no pruning is triggered (due to not being needed or
-    other reasons such as simultaneous pruning restrictions.
-
-    Checks for scheduling related conditions, then delegates the rest of the checks to
-    try_creating_prune_generator_task.
-    """
-
-    # skip pruning if no prune frequency is set
-    # pruning can still be forced via the API which will run a pruning task directly
-    if not cc_pair.connector.prune_freq:
-        return False
-
-    # skip pruning if not active
-    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
-        return False
-
-    # skip pruning if the next scheduled prune time hasn't been reached yet
-    last_pruned = cc_pair.last_pruned
-    if not last_pruned:
-        if not cc_pair.last_successful_index_time:
-            # if we've never indexed, we can't prune
-            return False
-
-        # if never pruned, use the last time the connector indexed successfully
-        last_pruned = cc_pair.last_successful_index_time
-
-    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
-    if datetime.now(timezone.utc) < next_prune:
-        return False
-
-    return True
-
-
 def try_creating_prune_generator_task(
     celery_app: Celery,
     cc_pair: ConnectorCredentialPair,
@@ -166,10 +154,16 @@ def try_creating_prune_generator_task(
         return None

     try:
-        if redis_connector.prune.fenced:  # skip pruning if already pruning
+        # skip pruning if already pruning
+        if redis_connector.prune.fenced:
             return None

-        if redis_connector.delete.fenced:  # skip pruning if the cc_pair is deleting
+        # skip pruning if the cc_pair is deleting
+        if redis_connector.delete.fenced:
             return None

+        # skip pruning if doc permissions sync is running
+        if redis_connector.permissions.fenced:
+            return None
+
         db_session.refresh(cc_pair)
```

```diff
@@ -59,7 +59,7 @@ def document_by_cc_pair_cleanup_task(
         connector / credential pair from the access list
     (6) delete all relevant entries from postgres
     """
-    task_logger.info(f"tenant={tenant_id} doc={document_id}")
+    task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")

     try:
         with get_session_with_tenant(tenant_id) as db_session:
@@ -141,7 +141,9 @@ def document_by_cc_pair_cleanup_task(
         return False
     except Exception as ex:
         if isinstance(ex, RetryError):
-            task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
+            task_logger.warning(
+                f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
+            )

             # only set the inner exception if it is of type Exception
             e_temp = ex.last_attempt.exception()
@@ -171,8 +173,8 @@ def document_by_cc_pair_cleanup_task(
         else:
             # This is the last attempt! mark the document as dirty in the db so that it
             # eventually gets fixed out of band via stale document reconciliation
-            task_logger.info(
-                f"Max retries reached. Marking doc as dirty for reconciliation: "
+            task_logger.warning(
+                f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
                 f"tenant={tenant_id} doc={document_id}"
             )
             with get_session_with_tenant(tenant_id):
```

```diff
@@ -27,6 +27,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerCeleryQueues
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector import fetch_connector_by_id
+from danswer.db.connector import mark_cc_pair_as_permissions_synced
 from danswer.db.connector import mark_ccpair_as_pruned
 from danswer.db.connector_credential_pair import add_deletion_failure_message
 from danswer.db.connector_credential_pair import (
@@ -58,6 +59,10 @@ from danswer.document_index.interfaces import VespaDocumentFields
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
+from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
+from danswer.redis.redis_connector_doc_perm_sync import (
+    RedisConnectorPermissionSyncData,
+)
 from danswer.redis.redis_connector_index import RedisConnectorIndex
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
 from danswer.redis.redis_document_set import RedisDocumentSet
@@ -546,6 +551,47 @@ def monitor_ccpair_pruning_taskset(
     redis_connector.prune.set_fence(False)


+def monitor_ccpair_permissions_taskset(
+    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
+) -> None:
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
+    if cc_pair_id_str is None:
+        task_logger.warning(
+            f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
+        )
+        return
+
+    cc_pair_id = int(cc_pair_id_str)
+
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+    if not redis_connector.permissions.fenced:
+        return
+
+    initial = redis_connector.permissions.generator_complete
+    if initial is None:
+        return
+
+    remaining = redis_connector.permissions.get_remaining()
+    task_logger.info(
+        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
+    )
+    if remaining > 0:
+        return
+
+    payload: RedisConnectorPermissionSyncData | None = (
+        redis_connector.permissions.payload
+    )
+    start_time: datetime | None = payload.started if payload else None
+
+    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
+    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
+
+    redis_connector.permissions.taskset_clear()
+    redis_connector.permissions.generator_clear()
+    redis_connector.permissions.set_fence(None)
+
+
 def monitor_ccpair_indexing_taskset(
     tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
 ) -> None:
@@ -668,13 +714,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
         n_pruning = celery_get_queue_length(
             DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
         )
+        n_permissions_sync = celery_get_queue_length(
+            DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
+        )

         task_logger.info(
             f"Queue lengths: celery={n_celery} "
             f"indexing={n_indexing} "
             f"sync={n_sync} "
             f"deletion={n_deletion} "
-            f"pruning={n_pruning}"
+            f"pruning={n_pruning} "
+            f"permissions_sync={n_permissions_sync} "
         )

         # do some cleanup before clearing fences
@@ -688,20 +738,22 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
                 get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
             )

-            for a in attempts:
+            for attempt in attempts:
                 # if attempts exist in the db but we don't detect them in redis, mark them as failed
                 fence_key = RedisConnectorIndex.fence_key_with_ids(
-                    a.connector_credential_pair_id, a.search_settings_id
+                    attempt.connector_credential_pair_id, attempt.search_settings_id
                 )
                 if not r.exists(fence_key):
                     failure_reason = (
                         f"Unknown index attempt. Might be left over from a process restart: "
-                        f"index_attempt={a.id} "
-                        f"cc_pair={a.connector_credential_pair_id} "
-                        f"search_settings={a.search_settings_id}"
+                        f"index_attempt={attempt.id} "
+                        f"cc_pair={attempt.connector_credential_pair_id} "
+                        f"search_settings={attempt.search_settings_id}"
                     )
                     task_logger.warning(failure_reason)
-                    mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
+                    mark_attempt_failed(
+                        attempt.id, db_session, failure_reason=failure_reason
+                    )

         lock_beat.reacquire()
         if r.exists(RedisConnectorCredentialPair.get_fence_key()):
@@ -741,6 +793,12 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
             with get_session_with_tenant(tenant_id) as db_session:
                 monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)

+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
+            lock_beat.reacquire()
+            with get_session_with_tenant(tenant_id) as db_session:
+                monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
+
         # uncomment for debugging if needed
         # r_celery = celery_app.broker_connection().channel().client
         # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@@ -811,7 +869,9 @@ def vespa_metadata_sync_task(
         )
     except Exception as ex:
         if isinstance(ex, RetryError):
-            task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
+            task_logger.warning(
+                f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
+            )

             # only set the inner exception if it is of type Exception
             e_temp = ex.last_attempt.exception()
```
@@ -29,18 +29,26 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Ensure the parent proc's database connections are not touched
in the new connection pool
"""Initialize the child process with a fresh SQLAlchemy Engine.

Based on the recommended approach in the SQLAlchemy docs found:
Based on SQLAlchemy's recommendations to handle multiprocessing:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}

logger.info("Initializing spawned worker child process.")

# Reset the engine in the child process
SqlEngine.reset_engine()

# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)

# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)

# Proceed with executing the target function
return func(*args, **kwargs)


@@ -33,8 +33,8 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
|
||||
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from danswer.utils.logger import IndexAttemptSingleton
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.logger import TaskAttemptSingleton
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -427,7 +427,7 @@ def run_indexing_entrypoint(
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_cc_and_index_id(
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
def name_sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
@@ -14,15 +14,6 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
|
||||
@@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
@@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
@@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
@@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@@ -295,7 +281,6 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -307,6 +292,9 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@@ -428,12 +416,20 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@@ -504,13 +500,19 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@@ -525,7 +527,13 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=final_msg,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@@ -537,6 +545,7 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
@@ -560,142 +569,39 @@ def stream_chat_message_objects(
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_style_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=bing_api_key,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
@@ -871,7 +777,6 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@@ -879,9 +784,11 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
@@ -915,7 +822,6 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@@ -925,7 +831,6 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
|
||||
@@ -503,3 +503,7 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
@@ -80,6 +80,10 @@ CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min

CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min

CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min

DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"


@@ -209,9 +213,17 @@ class PostgresAdvisoryLocks(Enum):


class DanswerCeleryQueues:
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"

# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"

# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"


@@ -221,8 +233,18 @@ class DanswerRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"

CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"


@@ -119,3 +119,14 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)


# if specified, will merge the specified JSON with the existing body of the
# request before sending it to the LLM
LITELLM_EXTRA_BODY: dict | None = None
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
if _LITELLM_EXTRA_BODY_RAW:
try:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass

@@ -146,7 +146,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, confluence_object["_links"]["webui"]
|
||||
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
object_text = None
|
||||
@@ -278,7 +278,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"]
|
||||
self.wiki_base,
|
||||
page["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
@@ -293,7 +295,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_metadata_list.append(
|
||||
SlimDocument(
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"]
|
||||
self.wiki_base,
|
||||
attachment["_links"]["webui"],
|
||||
self.is_cloud,
|
||||
),
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
|
||||
@@ -100,6 +100,39 @@ def extract_text_from_confluence_html(
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ri:page"):
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_title = html_page_reference.attrs["ri:content-title"]
|
||||
if not page_title:
|
||||
continue
|
||||
|
||||
page_query = f"type=page and title='{page_title}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page_batch in confluence_client.paginated_cql_page_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page_batch[0]
|
||||
break
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client, page_contents
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -153,7 +186,9 @@ def attachment_to_content(
return extracted_text


def build_confluence_document_id(base_url: str, content_url: str) -> str:
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document

@@ -164,6 +199,8 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str:
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"


@@ -305,6 +305,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
query = _build_time_range_query(time_range_start, time_range_end)
|
||||
doc_batch = []
|
||||
for user_email in self._get_all_user_emails():
|
||||
logger.info(f"Fetching slim threads for user: {user_email}")
|
||||
gmail_service = get_gmail_service(self.creds, user_email)
|
||||
for thread in execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
|
||||
@@ -192,23 +192,33 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def _update_traversed_parent_ids(self, folder_id: str) -> None:
|
||||
self._retrieved_ids.add(folder_id)
|
||||
|
||||
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
# Start with primary admin email
|
||||
user_emails = [self.primary_admin_email]
|
||||
|
||||
# Only fetch additional users if using service account
|
||||
if isinstance(self.creds, OAuthCredentials):
|
||||
return user_emails
|
||||
|
||||
admin_service = get_admin_service(
|
||||
creds=self.creds,
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
query = "isAdmin=true" if admins_only else "isAdmin=false"
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
# Get admins first since they're more likely to have access to most files
|
||||
for is_admin in [True, False]:
|
||||
query = "isAdmin=true" if is_admin else "isAdmin=false"
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
query=query,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
if email not in user_emails:
|
||||
user_emails.append(email)
|
||||
return user_emails
|
||||
|
||||
def _get_all_drive_ids(self) -> set[str]:
|
||||
primary_drive_service = get_drive_service(
|
||||
@@ -216,55 +226,48 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
user_email=self.primary_admin_email,
|
||||
)
|
||||
all_drive_ids = set()
|
||||
# We don't want to fail if we're using OAuth because you can
|
||||
# access your my drive as a non admin user in an org still
|
||||
ignore_fetch_failure = isinstance(self.creds, OAuthCredentials)
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=primary_drive_service.drives().list,
|
||||
list_key="drives",
|
||||
continue_on_404_or_403=ignore_fetch_failure,
|
||||
useDomainAdminAccess=True,
|
||||
fields="drives(id)",
|
||||
):
|
||||
all_drive_ids.add(drive["id"])
|
||||
return all_drive_ids
|
||||
|
||||
def _initialize_all_class_variables(self) -> None:
|
||||
# Get all user emails
|
||||
# Get admins first because they are more likely to have access to the most files
|
||||
user_emails = [self.primary_admin_email]
|
||||
for admins_only in [True, False]:
|
||||
for email in self._get_all_user_emails(admins_only=admins_only):
|
||||
if email not in user_emails:
|
||||
user_emails.append(email)
|
||||
self._all_org_emails = user_emails
|
||||
|
||||
self._all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
|
||||
# remove drive ids from the folder ids because they are queried differently
|
||||
self._requested_folder_ids -= self._all_drive_ids
|
||||
|
||||
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
|
||||
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
|
||||
if invalid_drive_ids:
|
||||
if not all_drive_ids:
|
||||
logger.warning(
|
||||
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
|
||||
"No drives found. This is likely because oauth user "
|
||||
"is not an admin and cannot view all drive IDs. "
|
||||
"Continuing with only the shared drive IDs specified in the config."
|
||||
)
|
||||
logger.warning("Checking for folder access instead...")
|
||||
self._requested_folder_ids.update(invalid_drive_ids)
|
||||
all_drive_ids = set(self._requested_shared_drive_ids)
|
||||
|
||||
if not self.include_shared_drives:
|
||||
self._requested_shared_drive_ids = set()
|
||||
elif not self._requested_shared_drive_ids:
|
||||
self._requested_shared_drive_ids = self._all_drive_ids
|
||||
return all_drive_ids
|
||||
|
||||
def _impersonate_user_for_retrieval(
|
||||
self,
|
||||
user_email: str,
|
||||
is_slim: bool,
|
||||
filtered_drive_ids: set[str],
|
||||
filtered_folder_ids: set[str],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
|
||||
# if we are including my drives, try to get the current user's my
|
||||
# drive if any of the following are true:
|
||||
# - no specific emails were requested
|
||||
# - the current user's email is in the requested emails
|
||||
# - we are using OAuth (in which case we assume that is the only email we will try)
|
||||
if self.include_my_drives and (
|
||||
not self._requested_my_drive_emails
|
||||
or user_email in self._requested_my_drive_emails
|
||||
or isinstance(self.creds, OAuthCredentials)
|
||||
):
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
@@ -274,7 +277,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end=end,
|
||||
)
|
||||
|
||||
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
|
||||
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
@@ -285,7 +288,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end=end,
|
||||
)
|
||||
|
||||
remaining_folders = self._requested_folder_ids - self._retrieved_ids
|
||||
remaining_folders = filtered_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
@@ -302,22 +305,56 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
self._initialize_all_class_variables()
|
||||
all_org_emails: list[str] = self._get_all_user_emails()
|
||||
|
||||
all_drive_ids: set[str] = self._get_all_drive_ids()
|
||||
|
||||
# remove drive ids from the folder ids because they are queried differently
|
||||
filtered_folder_ids = self._requested_folder_ids - all_drive_ids
|
||||
|
||||
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
|
||||
invalid_drive_ids = self._requested_shared_drive_ids - all_drive_ids
|
||||
if invalid_drive_ids:
|
||||
logger.warning(
|
||||
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
|
||||
)
|
||||
logger.warning("Checking for folder access instead...")
|
||||
filtered_folder_ids.update(invalid_drive_ids)
|
||||
|
||||
# If including shared drives, use the requested IDs if provided,
|
||||
# otherwise use all drive IDs
|
||||
filtered_drive_ids = set()
|
||||
if self.include_shared_drives:
|
||||
if self._requested_shared_drive_ids:
|
||||
# Remove invalid drive IDs from requested IDs
|
||||
filtered_drive_ids = (
|
||||
self._requested_shared_drive_ids - invalid_drive_ids
|
||||
)
|
||||
else:
|
||||
filtered_drive_ids = all_drive_ids
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
executor.submit(
|
||||
self._impersonate_user_for_retrieval, email, is_slim, start, end
|
||||
self._impersonate_user_for_retrieval,
|
||||
email,
|
||||
is_slim,
|
||||
filtered_drive_ids,
|
||||
filtered_folder_ids,
|
||||
start,
|
||||
end,
|
||||
): email
|
||||
for email in self._all_org_emails
|
||||
for email in all_org_emails
|
||||
}
|
||||
|
||||
# Yield results as they complete
|
||||
for future in as_completed(future_to_email):
|
||||
yield from future.result()
|
||||
|
||||
remaining_folders = self._requested_folder_ids - self._retrieved_ids
|
||||
remaining_folders = (
|
||||
filtered_drive_ids | filtered_folder_ids
|
||||
) - self._retrieved_ids
|
||||
if remaining_folders:
|
||||
logger.warning(
|
||||
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
|
||||
|
||||
@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
|
||||
)()
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.warning(f"Error executing request: {e}")
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -55,11 +55,11 @@ def validate_channel_names(
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
60 # How often pods send heartbeats to indicate they are still processing a tenant
15 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens

MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))

@@ -17,6 +17,8 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
@@ -75,6 +77,7 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -84,7 +87,9 @@ logger = setup_logger()
|
||||
|
||||
# Prometheus metric for HPA
|
||||
active_tenants_gauge = Gauge(
|
||||
"active_tenants", "Number of active tenants handled by this pod"
|
||||
"active_tenants",
|
||||
"Number of active tenants handled by this pod",
|
||||
["namespace", "pod"],
|
||||
)
|
||||
|
||||
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
|
||||
@@ -147,7 +152,9 @@ class SlackbotHandler:
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
self.acquire_tenants()
|
||||
active_tenants_gauge.set(len(self.tenant_ids))
|
||||
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
|
||||
len(self.tenant_ids)
|
||||
)
|
||||
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in Slack acquisition: {e}")
|
||||
@@ -164,9 +171,15 @@ class SlackbotHandler:
|
||||
|
||||
def acquire_tenants(self) -> None:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
|
||||
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
):
|
||||
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
|
||||
continue
|
||||
|
||||
if tenant_id in self.tenant_ids:
|
||||
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
|
||||
continue
|
||||
@@ -190,6 +203,9 @@ class SlackbotHandler:
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired lock for tenant {tenant_id}")
|
||||
self.tenant_ids.add(tenant_id)
|
||||
|
||||
for tenant_id in self.tenant_ids:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
@@ -236,14 +252,14 @@ class SlackbotHandler:
|
||||
|
||||
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
|
||||
|
||||
if tenant_id in self.socket_clients:
|
||||
if self.socket_clients.get(tenant_id):
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
|
||||
self.start_socket_client(tenant_id, slack_bot_tokens)
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if tenant_id in self.socket_clients:
|
||||
if self.socket_clients.get(tenant_id):
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
@@ -277,14 +293,14 @@ class SlackbotHandler:
|
||||
logger.info(f"Connecting socket client for tenant {tenant_id}")
|
||||
socket_client.connect()
|
||||
self.socket_clients[tenant_id] = socket_client
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def stop_socket_clients(self) -> None:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
for tenant_id, client in self.socket_clients.items():
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
if client:
|
||||
asyncio.run(client.close())
|
||||
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
|
||||
|
||||
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
|
||||
if not self.running:
|
||||
@@ -298,6 +314,16 @@ class SlackbotHandler:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
self.stop_socket_clients()
|
||||
|
||||
# Release locks for all tenants
|
||||
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
|
||||
for tenant_id in self.tenant_ids:
|
||||
try:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(DanswerRedisLocks.SLACK_BOT_LOCK)
|
||||
logger.info(f"Released lock for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
|
||||
|
||||
# Wait for background threads to finish (with timeout)
|
||||
logger.info("Waiting for background threads to finish...")
|
||||
self.acquire_thread.join(timeout=5)
|
||||
|
||||
@@ -282,3 +282,32 @@ def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:

cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()


def mark_cc_pair_as_permissions_synced(
db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

cc_pair.last_time_perm_sync = start_time
db_session.commit()


def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

# The sync time can be marked after it ran because all group syncs
# are run in full, not polling for changes.
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()

@@ -76,8 +76,10 @@ def _add_user_filters(
|
||||
.where(~UG__CCpair.user_group_id.in_(user_groups))
|
||||
.correlate(ConnectorCredentialPair)
|
||||
)
|
||||
where_clause |= ConnectorCredentialPair.creator_id == user.id
|
||||
else:
|
||||
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
|
||||
where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -387,6 +389,7 @@ def add_credential_to_connector(
|
||||
)
|
||||
|
||||
association = ConnectorCredentialPair(
|
||||
creator_id=user.id if user else None,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
|
||||
@@ -19,6 +19,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import null
|
||||
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
@@ -46,13 +47,21 @@ def count_documents_by_needs_sync(session: Session) -> int:
|
||||
"""Get the count of all documents where:
|
||||
1. last_modified is newer than last_synced
|
||||
2. last_synced is null (meaning we've never synced)
|
||||
AND the document has a relationship with a connector/credential pair
|
||||
|
||||
TODO: The documents without a relationship with a connector/credential pair
|
||||
should be cleaned up somehow eventually.
|
||||
|
||||
This function executes the query and returns the count of
|
||||
documents matching the criteria."""
|
||||
|
||||
count = (
|
||||
session.query(func.count())
|
||||
session.query(func.count(DbDocument.id.distinct()))
|
||||
.select_from(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
or_(
|
||||
DbDocument.last_modified > DbDocument.last_synced,
|
||||
@@ -91,6 +100,22 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
return stmt
|
||||
|
||||
|
||||
def get_all_documents_needing_vespa_sync_for_cc_pair(
|
||||
db_session: Session, cc_pair_id: int
|
||||
) -> list[DbDocument]:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def construct_document_select_for_connector_credential_pair(
|
||||
connector_id: int, credential_id: int | None = None
|
||||
) -> Select:
|
||||
@@ -104,6 +129,21 @@ def construct_document_select_for_connector_credential_pair(
|
||||
return stmt
|
||||
|
||||
|
||||
def get_documents_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> list[DbDocument]:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def get_document_ids_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> list[str]:
|
||||
@@ -268,7 +308,7 @@ def get_access_info_for_documents(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def upsert_documents(
|
||||
def _upsert_documents(
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
initial_boost: int = DEFAULT_BOOST,
|
||||
@@ -306,6 +346,8 @@ def upsert_documents(
|
||||
]
|
||||
)
|
||||
|
||||
# This does not update the permissions of the document if
|
||||
# the document already exists.
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
@@ -322,7 +364,7 @@ def upsert_documents(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_document_by_connector_credential_pair(
|
||||
def _upsert_document_by_connector_credential_pair(
|
||||
db_session: Session, document_metadata_batch: list[DocumentMetadata]
|
||||
) -> None:
|
||||
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
|
||||
@@ -404,8 +446,8 @@ def upsert_documents_complete(
|
||||
db_session: Session,
|
||||
document_metadata_batch: list[DocumentMetadata],
|
||||
) -> None:
|
||||
upsert_documents(db_session, document_metadata_batch)
|
||||
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
|
||||
_upsert_documents(db_session, document_metadata_batch)
|
||||
_upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
|
||||
logger.info(
|
||||
f"Upserted {len(document_metadata_batch)} document store entries into DB"
|
||||
)
|
||||
@@ -463,7 +505,6 @@ def delete_documents_complete__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
logger.info(f"Deleting {len(document_ids)} documents from the DB")
|
||||
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_document_feedback_for_documents__no_commit(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
|
||||
@@ -189,6 +189,13 @@ class SqlEngine:
return ""
return cls._app_name

@classmethod
def reset_engine(cls) -> None:
with cls._lock:
if cls._engine:
cls._engine.dispose()
cls._engine = None


def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
@@ -312,7 +319,9 @@ async def get_async_session_with_tenant(
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
@@ -373,7 +382,9 @@ def get_session_with_tenant(
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
@@ -126,8 +126,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
chosen_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
visible_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
@@ -173,6 +173,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
)
|
||||
# Whether the user has logged in via web. False if user has only used Danswer through Slack bot
|
||||
has_web_login: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
back_populates="creator",
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt(Base):
|
||||
@@ -420,6 +425,9 @@ class ConnectorCredentialPair(Base):
|
||||
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# Time finished, not used for calculating backend jobs which uses time started (created)
|
||||
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
@@ -452,6 +460,14 @@ class ConnectorCredentialPair(Base):
|
||||
"IndexAttempt", back_populates="connector_credential_pair"
|
||||
)
|
||||
|
||||
# the user id of the user that created this cc pair
|
||||
creator_id: Mapped[UUID | None] = mapped_column(nullable=True)
|
||||
creator: Mapped["User"] = relationship(
|
||||
"User",
|
||||
back_populates="cc_pairs",
|
||||
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
|
||||
@@ -743,5 +743,4 @@ def delete_persona_by_name(
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
return tool
|
||||
|
||||
|
||||
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
|
||||
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
|
||||
if not tool:
|
||||
raise ValueError("Tool by specified name does not exist")
|
||||
return tool
|
||||
|
||||
|
||||
def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
@@ -37,7 +44,7 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
custom_headers=[header.model_dump() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
|
||||
@@ -97,3 +97,18 @@ def batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session.flush() # generate ids
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
|
||||
def batch_add_non_web_user_if_not_exists(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> list[User]:
|
||||
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
|
||||
|
||||
new_users: list[User] = []
|
||||
for email in missing_user_emails:
|
||||
new_users.append(_generate_non_web_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.commit()
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
@@ -56,7 +56,7 @@ class IndexingPipelineProtocol(Protocol):
        ...


def upsert_documents_in_db(
def _upsert_documents_in_db(
    documents: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
    db_session: Session,
@@ -243,7 +243,7 @@ def index_doc_batch_prepare(

    # Create records in the source of truth about these documents,
    # does not include doc_updated_at which is also used to indicate a successful update
    upsert_documents_in_db(
    _upsert_documents_in_db(
        documents=documents,
        index_attempt_metadata=index_attempt_metadata,
        db_session=db_session,
@@ -255,7 +255,7 @@ def index_doc_batch_prepare(
    )


@log_function_time()
@log_function_time(debug_only=True)
def index_doc_batch(
    *,
    chunker: Chunker,
@@ -26,6 +26,7 @@ from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.llm.interfaces import ToolChoiceOptions
@@ -213,6 +214,7 @@ class DefaultMultiLLM(LLM):
        temperature: float = GEN_AI_TEMPERATURE,
        custom_config: dict[str, str] | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict | None = LITELLM_EXTRA_BODY,
    ):
        self._timeout = timeout
        self._model_provider = model_provider
@@ -246,6 +248,8 @@ class DefaultMultiLLM(LLM):
        model_kwargs: dict[str, Any] = {}
        if extra_headers:
            model_kwargs.update({"extra_headers": extra_headers})
        if extra_body:
            model_kwargs.update({"extra_body": extra_body})

        self._model_kwargs = model_kwargs
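Reviewer note on the hunk above: a minimal sketch of how the new extra_headers / extra_body arguments are expected to end up in the per-request kwargs. The helper name below is illustrative only; the real merge happens inside DefaultMultiLLM, and the assertion just demonstrates the resulting shape.

from typing import Any

def build_extra_kwargs(
    extra_headers: dict[str, str] | None,
    extra_body: dict | None,
) -> dict[str, Any]:
    # mirrors the constructor logic: only set keys that were actually provided
    model_kwargs: dict[str, Any] = {}
    if extra_headers:
        model_kwargs["extra_headers"] = extra_headers
    if extra_body:
        model_kwargs["extra_body"] = extra_body
    return model_kwargs

assert build_extra_kwargs({"X-Org": "acme"}, None) == {"extra_headers": {"X-Org": "acme"}}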
@@ -74,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
    get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
    admin_router as admin_query_router,
@@ -270,6 +273,9 @@ def get_application() -> FastAPI:
        application, token_rate_limit_settings_router
    )
    include_router_with_global_prefix_prepended(application, indexing_router)
    include_router_with_global_prefix_prepended(
        application, get_full_openai_assistants_api_router()
    )

    if AUTH_TYPE == AuthType.DISABLED:
        # Server logs this during auth setup verification step
@@ -309,7 +315,7 @@ def get_application() -> FastAPI:
        tags=["users"],
    )

    if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
    if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
        oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
        include_router_with_global_prefix_prepended(
            application,
@@ -1,6 +1,8 @@
import redis

from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -19,6 +21,10 @@ class RedisConnector:
        self.stop = RedisConnectorStop(tenant_id, id, self.redis)
        self.prune = RedisConnectorPrune(tenant_id, id, self.redis)
        self.delete = RedisConnectorDelete(tenant_id, id, self.redis)
        self.permissions = RedisConnectorPermissionSync(tenant_id, id, self.redis)
        self.external_group_sync = RedisConnectorExternalGroupSync(
            tenant_id, id, self.redis
        )

    def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
        return RedisConnectorIndex(
@@ -63,6 +63,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
        stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
            cc_pair.connector_id, cc_pair.credential_id
        )

        for doc in db_session.scalars(stmt).yield_per(1):
            current_time = time.monotonic()
            if current_time - last_lock_time >= (
187
backend/danswer/redis/redis_connector_doc_perm_sync.py
Normal file
@@ -0,0 +1,187 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from pydantic import BaseModel

from danswer.access.models import DocExternalAccess
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues


class RedisConnectorPermissionSyncData(BaseModel):
    started: datetime | None


class RedisConnectorPermissionSync:
    """Manages interactions with redis for doc permission sync tasks. Should only be accessed
    through RedisConnector."""

    PREFIX = "connectordocpermissionsync"

    FENCE_PREFIX = f"{PREFIX}_fence"

    # phase 1 - generator task and progress signals
    GENERATORTASK_PREFIX = f"{PREFIX}+generator"  # connectorpermissions+generator
    GENERATOR_PROGRESS_PREFIX = (
        PREFIX + "_generator_progress"
    )  # connectorpermissions_generator_progress
    GENERATOR_COMPLETE_PREFIX = (
        PREFIX + "_generator_complete"
    )  # connectorpermissions_generator_complete

    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpermissions_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpermissions+sub

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str | None = tenant_id
        self.id = id
        self.redis = redis

        self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
        self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
        self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
        self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"

        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)

    def generator_clear(self) -> None:
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)

    def get_remaining(self) -> int:
        remaining = cast(int, self.redis.scard(self.taskset_key))
        return remaining

    def get_active_task_count(self) -> int:
        """Count of active permission sync tasks"""
        count = 0
        for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
            count += 1
        return count

    @property
    def fenced(self) -> bool:
        if self.redis.exists(self.fence_key):
            return True

        return False

    @property
    def payload(self) -> RedisConnectorPermissionSyncData | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorPermissionSyncData.model_validate_json(
            cast(str, fence_str)
        )

        return payload

    def set_fence(
        self,
        payload: RedisConnectorPermissionSyncData | None,
    ) -> None:
        if not payload:
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())

    @property
    def generator_complete(self) -> int | None:
        """the fence payload is an int representing the starting number of
        permission sync tasks to be processed ... just after the generator completes."""
        fence_bytes = self.redis.get(self.generator_complete_key)
        if fence_bytes is None:
            return None

        if fence_bytes == b"None":
            return None

        fence_int = int(cast(bytes, fence_bytes).decode())
        return fence_int

    @generator_complete.setter
    def generator_complete(self, payload: int | None) -> None:
        """Set the payload to an int to set the fence, otherwise if None it will
        be deleted"""
        if payload is None:
            self.redis.delete(self.generator_complete_key)
            return

        self.redis.set(self.generator_complete_key, payload)

    def generate_tasks(
        self,
        celery_app: Celery,
        lock: redis.lock.Lock | None,
        new_permissions: list[DocExternalAccess],
        source_string: str,
    ) -> int | None:
        last_lock_time = time.monotonic()
        async_results = []

        # Create a task for each document permission sync
        for doc_perm in new_permissions:
            current_time = time.monotonic()
            if lock and current_time - last_lock_time >= (
                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
            ):
                lock.reacquire()
                last_lock_time = current_time
            # Add task for document permissions sync
            custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
            self.redis.sadd(self.taskset_key, custom_task_id)

            result = celery_app.send_task(
                "update_external_document_permissions_task",
                kwargs=dict(
                    tenant_id=self.tenant_id,
                    serialized_doc_external_access=doc_perm.to_dict(),
                    source_string=source_string,
                ),
                queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
                task_id=custom_task_id,
                priority=DanswerCeleryPriority.MEDIUM,
            )
            async_results.append(result)

        return len(async_results)

    @staticmethod
    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
        taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
        r.srem(taskset_key, task_id)
        return

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(
            RedisConnectorPermissionSync.GENERATOR_COMPLETE_PREFIX + "*"
        ):
            r.delete(key)

        for key in r.scan_iter(
            RedisConnectorPermissionSync.GENERATOR_PROGRESS_PREFIX + "*"
        ):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
            r.delete(key)
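Reviewer note: a rough usage sketch of the fence/taskset flow defined above, assuming an already-constructed RedisConnector and Celery app; the actual beat task and worker wiring live in the celery task modules, so treat celery_app, cc_pair_id and new_permissions below as placeholders.

from datetime import datetime, timezone

redis_connector = RedisConnector(tenant_id, cc_pair_id)  # assumed constructor shape

if not redis_connector.permissions.fenced:
    # fence off this cc_pair so a second permission sync cannot start concurrently
    redis_connector.permissions.set_fence(
        RedisConnectorPermissionSyncData(started=datetime.now(timezone.utc))
    )

    # enqueue one celery task per changed document permission and record the
    # starting count so progress can be reported as the taskset drains
    initial = redis_connector.permissions.generate_tasks(
        celery_app, None, new_permissions, "confluence"
    )
    redis_connector.permissions.generator_complete = initial

# monitoring side: get_remaining() hitting 0 means every permission upsert finished
remaining = redis_connector.permissions.get_remaining()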
133
backend/danswer/redis/redis_connector_ext_group_sync.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from celery import Celery
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class RedisConnectorExternalGroupSync:
|
||||
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectorexternalgroupsync"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
|
||||
# phase 1 - generator task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorexternalgroupsync+generator
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # connectorexternalgroupsync_generator_progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorexternalgroupsync_generator_complete
|
||||
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset
|
||||
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
|
||||
self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
|
||||
self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
|
||||
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_remaining(self) -> int:
|
||||
# todo: move into fence
|
||||
remaining = cast(int, self.redis.scard(self.taskset_key))
|
||||
return remaining
|
||||
|
||||
def get_active_task_count(self) -> int:
|
||||
"""Count of active external group syncing tasks"""
|
||||
count = 0
|
||||
for _ in self.redis.scan_iter(
|
||||
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"
|
||||
):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def set_fence(self, value: bool) -> None:
|
||||
if not value:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
|
||||
@property
|
||||
def generator_complete(self) -> int | None:
|
||||
"""the fence payload is an int representing the starting number of
|
||||
external group syncing tasks to be processed ... just after the generator completes.
|
||||
"""
|
||||
fence_bytes = self.redis.get(self.generator_complete_key)
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
if fence_bytes == b"None":
|
||||
return None
|
||||
|
||||
fence_int = int(cast(bytes, fence_bytes).decode())
|
||||
return fence_int
|
||||
|
||||
@generator_complete.setter
|
||||
def generator_complete(self, payload: int | None) -> None:
|
||||
"""Set the payload to an int to set the fence, otherwise if None it will
|
||||
be deleted"""
|
||||
if payload is None:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
lock: redis.lock.Lock | None,
|
||||
) -> int | None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
|
||||
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"
|
||||
r.srem(taskset_key, task_id)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.GENERATOR_COMPLETE_PREFIX + "*"
|
||||
):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(
|
||||
RedisConnectorExternalGroupSync.GENERATOR_PROGRESS_PREFIX + "*"
|
||||
):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
File diff suppressed because it is too large
44
backend/danswer/seeding/initial_docs_cohere.json
Normal file
@@ -0,0 +1,44 @@
|
||||
[
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/overview",
|
||||
"title": "Use Cases Overview",
|
||||
"content": "How to leverage Danswer in your organization\n\nDanswer Overview\nDanswer is the AI Assistant connected to your organization's docs, apps, and people. Danswer makes Generative AI more versatile for work by enabling new types of questions like \"What is the most common feature request we've heard from customers this month\". Whereas other AI systems have no context of your team and are generally unhelpful with work related questions, Danswer makes it possible to ask these questions in natural language and get back answers in seconds.\n\nDanswer can connect to +30 different tools and the use cases are not limited to the ones in the following pages. The highlighted use cases are for inspiration and come from feedback gathered from our users and customers.\n\n\nCommon Getting Started Questions:\n\nWhy are these docs connected in my Danswer deployment?\nAnswer: This is just an example of how connectors work in Danswer. You can connect up your own team's knowledge and you will be able to ask questions unique to your organization. Danswer will keep all of the knowledge up to date and in sync with your connected applications.\n\nIs my data being sent anywhere when I connect it up to Danswer?\nAnswer: No! Danswer is built with data security as our highest priority. We open sourced it so our users can know exactly what is going on with their data. By default all of the document processing happens within Danswer. The only time it is sent outward is for the GenAI call to generate answers.\n\nWhere is the feature for auto sync-ing document level access permissions from all connected sources?\nAnswer: This falls under the Enterprise Edition set of Danswer features built on top of the MIT/community edition. If you are on Danswer Cloud, you have access to them by default. If you're running it yourself, reach out to the Danswer team to receive access.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
|
||||
"title": "Enterprise Search",
|
||||
"content": "Value of Enterprise Search with Danswer\n\nWhat is Enterprise Search and why is it Important?\nAn Enterprise Search system gives team members a single place to access all of the disparate knowledge of an organization. Critical information is saved across a host of channels like call transcripts with prospects, engineering design docs, IT runbooks, customer support email exchanges, project management tickets, and more. As fast moving teams scale up, information gets spread out and more disorganized.\n\nSince it quickly becomes infeasible to check across every source, decisions get made on incomplete information, employee satisfaction decreases, and the most valuable members of your team are tied up with constant distractions as junior teammates are unable to unblock themselves. Danswer solves this problem by letting anyone on the team access all of the knowledge across your organization in a permissioned and secure way. Users can ask questions in natural language and get back answers and documents across all of the connected sources instantly.\n\nWhat's the real cost?\nA typical knowledge worker spends over 2 hours a week on search, but more than that, the cost of incomplete or incorrect information can be extremely high. Customer support/success that isn't able to find the reference to similar cases could cause hours or even days of delay leading to lower customer satisfaction or in the worst case - churn. An account exec not realizing that a prospect had previously mentioned a specific need could lead to lost deals. An engineer not realizing a similar feature had previously been built could result in weeks of wasted development time and tech debt with duplicate implementation. With a lack of knowledge, your whole organization is navigating in the dark - inefficient and mistake prone.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/enterprise_search",
|
||||
"title": "Enterprise Search",
|
||||
"content": "More than Search\nWhen analyzing the entire corpus of knowledge within your company is as easy as asking a question in a search bar, your entire team can stay informed and up to date. Danswer also makes it trivial to identify where knowledge is well documented and where it is lacking. Team members who are centers of knowledge can begin to effectively document their expertise since it is no longer being thrown into a black hole. All of this allows the organization to achieve higher efficiency and drive business outcomes.\n\nWith Generative AI, the entire user experience has evolved as well. For example, instead of just finding similar cases for your customer support team to reference, Danswer breaks down the issue and explains it so that even the most junior members can understand it. This in turn lets them give the most holistic and technically accurate response possible to your customers. On the other end, even the super stars of your sales team will not be able to review 10 hours of transcripts before hopping on that critical call, but Danswer can easily parse through it in mere seconds and give crucial context to help your team close.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/ai_platform",
|
||||
"title": "AI Platform",
|
||||
"content": "Build AI Agents powered by the knowledge and workflows specific to your organization.\n\nBeyond Answers\nAgents enabled by generative AI and reasoning capable models are helping teams to automate their work. Danswer is helping teams make it happen. Danswer provides out of the box user chat sessions, attaching custom tools, handling LLM reasoning, code execution, data analysis, referencing internal knowledge, and much more.\n\nDanswer as a platform is not a no-code agent builder. We are made by developers for developers and this gives your team the full flexibility and power to create agents not constrained by blocks and simple logic paths.\n\nFlexibility and Extensibility\nDanswer is open source and completely whitebox. This not only gives transparency to what happens within the system but also means that your team can directly modify the source code to suit your unique needs.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/customer_support",
|
||||
"title": "Customer Support",
|
||||
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nDanswer takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/sales",
|
||||
"title": "Sales",
|
||||
"content": "Keep your team up to date on every conversation and update so they can close.\n\nRecall Every Detail\nBeing able to instantly revisit every detail of any call without reading transcripts is helping Sales teams provide more tailored pitches, build stronger relationships, and close more deals. Instead of searching and reading through hours of transcripts in preparation for a call, your team can now ask Danswer \"What specific features was ACME interested in seeing for the demo\". Since your team doesn't have time to read every transcript prior to a call, Danswer provides a more thorough summary because it can instantly parse hundreds of pages and distill out the relevant information. Even for fast lookups it becomes much more convenient - for example to brush up on connection building topics by asking \"What rapport building topic did we chat about in the last call with ACME\".\n\nKnow Every Product Update\nIt is impossible for Sales teams to keep up with every product update. Because of this, when a prospect has a question that the Sales team does not know, they have no choice but to rely on the Product and Engineering orgs to get an authoritative answer. Not only is this distracting to the other teams, it also slows down the time to respond to the prospect (and as we know, time is the biggest killer of deals). With Danswer, it is even possible to get answers live on call because of how fast accessing information becomes. A question like \"Have we shipped the Microsoft AD integration yet?\" can now be answered in seconds meaning that prospects can get answers while on the call instead of asynchronously and sales cycles are reduced as a result.",
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.danswer.dev/more/use_cases/operations",
|
||||
"title": "Operations",
|
||||
"content": "Double the productivity of your Ops teams like IT, HR, etc.\n\nAutomatically Resolve Tickets\nModern teams are leveraging AI to auto-resolve up to 50% of tickets. Whether it is an employee asking about benefits details or how to set up the VPN for remote work, Danswer can help your team help themselves. This frees up your team to do the real impactful work of landing star candidates or improving your internal processes.\n\nAI Aided Onboarding\nOne of the periods where your team needs the most help is when they're just ramping up. Instead of feeling lost in dozens of new tools, Danswer gives them a single place where they can ask about anything in natural language. Whether it's how to set up their work environment or what their onboarding goals are, Danswer can walk them through every step with the help of Generative AI. This lets your team feel more empowered and gives time back to the more seasoned members of your team to focus on moving the needle.",
|
||||
"chunk_ind": 0
|
||||
}
|
||||
]
|
||||
@@ -32,7 +32,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.server.documents.models import ConnectorBase
from danswer.utils.logger import setup_logger
from danswer.utils.retry_wrapper import retry_builder

from danswer.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger()

@@ -91,7 +91,21 @@ def _create_indexable_chunks(
    return list(ids_to_documents.values()), chunks


def seed_initial_documents(db_session: Session, tenant_id: str | None) -> None:
# Cohere is used in EE version
def load_processed_docs(cohere_enabled: bool) -> list[dict]:
    initial_docs_path = os.path.join(
        os.getcwd(),
        "danswer",
        "seeding",
        "initial_docs.json",
    )
    processed_docs = json.load(open(initial_docs_path))
    return processed_docs


def seed_initial_documents(
    db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
    """
    Seed initial documents so users don't have an empty index to start

@@ -132,7 +146,9 @@ def seed_initial_documents(db_session: Session, tenant_id: str | None) -> None:
        return

    search_settings = get_current_search_settings(db_session)
    if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL:
    if search_settings.model_name != DEFAULT_DOCUMENT_ENCODER_MODEL and not (
        search_settings.model_name == "embed-english-v3.0" and cohere_enabled
    ):
        logger.info("Embedding model has been updated, skipping")
        return

@@ -172,11 +188,10 @@ def seed_initial_documents(db_session: Session, tenant_id: str | None) -> None:
        last_successful_index_time=last_index_time,
    )
    cc_pair_id = cast(int, result.data)

    initial_docs_path = os.path.join(
        os.getcwd(), "danswer", "seeding", "initial_docs.json"
    )
    processed_docs = json.load(open(initial_docs_path))
    processed_docs = fetch_versioned_implementation(
        "danswer.seeding.load_docs",
        "load_processed_docs",
    )(cohere_enabled)

    docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)
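Reviewer note: seed_initial_documents now resolves load_processed_docs through fetch_versioned_implementation, which is what lets an EE build swap in the Cohere-embedded seed document set. A hedged sketch of what such an override could look like; the file choice and behaviour below are assumptions for illustration, not the actual EE code.

import json
import os

def load_processed_docs(cohere_enabled: bool) -> list[dict]:
    # assumed EE behaviour: pick the Cohere-embedded seed docs when enabled
    filename = "initial_docs_cohere.json" if cohere_enabled else "initial_docs.json"
    initial_docs_path = os.path.join(os.getcwd(), "danswer", "seeding", filename)
    with open(initial_docs_path) as f:
        return json.load(f)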
@@ -6,6 +6,7 @@ from starlette.routing import BaseRoute

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
@@ -102,7 +103,8 @@ def check_router_auth(
        for dependency in route_dependant_obj.dependencies:
            depends_fn = dependency.cache_key[0]
            if (
                depends_fn == current_user
                depends_fn == current_limited_user
                or depends_fn == current_user
                or depends_fn == current_admin_user
                or depends_fn == current_curator_or_admin_user
                or depends_fn == api_key_dep
@@ -118,5 +120,5 @@ def check_router_auth(
    # print(f"(\"{route.path}\", {set(route.methods)}),")

    raise RuntimeError(
        f"Did not find current_user or current_admin_user dependency in route - {route}"
        f"Did not find user dependency in private route - {route}"
    )
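Reviewer note: the audit above walks every registered route and raises if none of the accepted auth dependencies is attached. With current_limited_user added to that set, a route like the sketch below now passes the check; the router and path here are made up for illustration only.

from fastapi import APIRouter, Depends

from danswer.auth.users import current_limited_user

demo_router = APIRouter()  # illustrative only, not a real Danswer router


@demo_router.get("/demo/persona")
def demo_get_persona(user=Depends(current_limited_user)) -> dict:
    return {"ok": True}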
@@ -12,13 +12,13 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
try_creating_permissions_sync_task,
|
||||
)
|
||||
from danswer.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
||||
@@ -26,6 +26,7 @@ from danswer.db.connector_credential_pair import (
|
||||
update_connector_credential_pair_from_id,
|
||||
)
|
||||
from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.document import get_documents_for_cc_pair
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
@@ -38,15 +39,13 @@ from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.server.documents.models import CCStatusUpdateRequest
|
||||
from danswer.server.documents.models import CeleryTaskStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||
from danswer.server.documents.models import DocumentSyncStatus
|
||||
from danswer.server.documents.models import PaginatedIndexAttempts
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -288,12 +287,12 @@ def prune_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync")
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
def get_cc_pair_latest_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CeleryTaskStatus:
|
||||
) -> datetime | None:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
@@ -303,34 +302,20 @@ def get_cc_pair_latest_sync(
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
# look up the last sync task for this connector (if it exists)
|
||||
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||
if not last_sync_task:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="No sync task found.",
|
||||
)
|
||||
|
||||
return CeleryTaskStatus(
|
||||
id=last_sync_task.task_id,
|
||||
name=last_sync_task.task_name,
|
||||
status=last_sync_task.status,
|
||||
start_time=last_sync_task.start_time,
|
||||
register_time=last_sync_task.register_time,
|
||||
)
|
||||
return cc_pair.last_time_perm_sync
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync")
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions")
|
||||
def sync_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
# avoiding circular refs
|
||||
"""Triggers permissions sync on a particular cc_pair immediately"""
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -344,37 +329,49 @@ def sync_cc_pair(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
sync_task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair_id)
|
||||
last_sync_task = get_latest_task(sync_task_name, db_session)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if last_sync_task and check_task_is_live_and_not_timed_out(
|
||||
last_sync_task, db_session
|
||||
):
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.permissions.fenced:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="Sync task already in progress.",
|
||||
detail="Doc permissions sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
|
||||
sync_external_doc_permissions_task = fetch_ee_implementation_or_noop(
|
||||
"danswer.background.celery.apps.primary",
|
||||
"sync_external_doc_permissions_task",
|
||||
None,
|
||||
logger.info(
|
||||
f"Doc permissions sync cc_pair={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
|
||||
if sync_external_doc_permissions_task:
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
),
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not tasks_created:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="Doc permissions sync task creation failed.",
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the sync task.",
|
||||
message="Successfully created the doc permissions sync task.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
|
||||
def get_docs_sync_status(
|
||||
cc_pair_id: int,
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[DocumentSyncStatus]:
|
||||
all_docs_for_cc_pair = get_documents_for_cc_pair(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair]
|
||||
|
||||
|
||||
@router.put("/connector/{connector_id}/credential/{credential_id}")
|
||||
def associate_credential_to_connector(
|
||||
connector_id: int,
|
||||
@@ -390,6 +387,7 @@ def associate_credential_to_connector(
|
||||
user=user,
|
||||
target_group_ids=metadata.groups,
|
||||
object_is_public=metadata.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=metadata.access_type == AccessType.SYNC,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -81,7 +81,6 @@ from danswer.db.index_attempt import get_latest_index_attempts_by_status
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
@@ -665,7 +664,8 @@ def create_connector_from_model(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.is_public,
|
||||
object_is_public=connector_data.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
return create_connector(
|
||||
@@ -683,32 +683,31 @@ def create_connector_with_mock_credential(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
if connector_data.is_public:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="User does not have permission to create public credentials",
|
||||
)
|
||||
if not connector_data.groups:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Curators must specify 1+ groups",
|
||||
)
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
|
||||
)
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
connector_response = create_connector(
|
||||
db_session=db_session, connector_data=connector_data
|
||||
db_session=db_session,
|
||||
connector_data=connector_data,
|
||||
)
|
||||
|
||||
mock_credential = CredentialBase(
|
||||
credential_json={}, admin_public=True, source=connector_data.source
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=connector_data.source,
|
||||
)
|
||||
credential = create_credential(
|
||||
mock_credential, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
access_type = (
|
||||
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
||||
credential_data=mock_credential,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
@@ -716,7 +715,7 @@ def create_connector_with_mock_credential(
|
||||
user=user,
|
||||
connector_id=cast(int, connector_response.id), # will always be an int
|
||||
credential_id=credential.id,
|
||||
access_type=access_type,
|
||||
access_type=connector_data.access_type,
|
||||
cc_pair_name=connector_data.name,
|
||||
groups=connector_data.groups,
|
||||
)
|
||||
@@ -741,7 +740,8 @@ def update_connector_from_model(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.is_public,
|
||||
object_is_public=connector_data.access_type == AccessType.PUBLIC,
|
||||
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
except ValueError as e:
|
||||
|
||||
@@ -14,6 +14,7 @@ from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexAttemptError as DbIndexAttemptError
|
||||
from danswer.db.models import IndexingStatus
|
||||
@@ -21,6 +22,20 @@ from danswer.db.models import TaskStatus
|
||||
from danswer.server.utils import mask_credential_dict
|
||||
|
||||
|
||||
class DocumentSyncStatus(BaseModel):
|
||||
doc_id: str
|
||||
last_synced: datetime | None
|
||||
last_modified: datetime | None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, doc: DbDocument) -> "DocumentSyncStatus":
|
||||
return DocumentSyncStatus(
|
||||
doc_id=doc.id,
|
||||
last_synced=doc.last_synced,
|
||||
last_modified=doc.last_modified,
|
||||
)
|
||||
|
||||
|
||||
class DocumentInfo(BaseModel):
|
||||
num_chunks: int
|
||||
num_tokens: int
|
||||
@@ -49,11 +64,11 @@ class ConnectorBase(BaseModel):
|
||||
|
||||
|
||||
class ConnectorUpdateRequest(ConnectorBase):
|
||||
is_public: bool = True
|
||||
access_type: AccessType
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
def to_connector_base(self) -> ConnectorBase:
|
||||
return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"}))
|
||||
return ConnectorBase(**self.model_dump(exclude={"access_type", "groups"}))
|
||||
|
||||
|
||||
class ConnectorSnapshot(ConnectorBase):
|
||||
@@ -222,6 +237,8 @@ class CCPairFullInfo(BaseModel):
|
||||
is_editable_for_current_user: bool
|
||||
deletion_failure_message: str | None
|
||||
indexing: bool
|
||||
creator: UUID | None
|
||||
creator_email: str | None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
@@ -267,6 +284,10 @@ class CCPairFullInfo(BaseModel):
|
||||
is_editable_for_current_user=is_editable_for_current_user,
|
||||
deletion_failure_message=cc_pair_model.deletion_failure_message,
|
||||
indexing=indexing,
|
||||
creator=cc_pair_model.creator_id,
|
||||
creator_email=cc_pair_model.creator.email
|
||||
if cc_pair_model.creator
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import NotificationType
|
||||
@@ -272,7 +273,7 @@ def list_personas(
|
||||
@basic_router.get("/{persona_id}")
|
||||
def get_persona(
|
||||
persona_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
user: User | None = Depends(current_limited_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
return PersonaSnapshot.from_model(
|
||||
|
||||
@@ -630,31 +630,25 @@ def update_user_assistant_list(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_assistant_list(
|
||||
def update_assistant_visibility(
|
||||
preferences: UserPreferences, assistant_id: int, show: bool
|
||||
) -> UserPreferences:
|
||||
visible_assistants = preferences.visible_assistants or []
|
||||
hidden_assistants = preferences.hidden_assistants or []
|
||||
chosen_assistants = preferences.chosen_assistants or []
|
||||
|
||||
if show:
|
||||
if assistant_id not in visible_assistants:
|
||||
visible_assistants.append(assistant_id)
|
||||
if assistant_id in hidden_assistants:
|
||||
hidden_assistants.remove(assistant_id)
|
||||
if assistant_id not in chosen_assistants:
|
||||
chosen_assistants.append(assistant_id)
|
||||
else:
|
||||
if assistant_id in visible_assistants:
|
||||
visible_assistants.remove(assistant_id)
|
||||
if assistant_id not in hidden_assistants:
|
||||
hidden_assistants.append(assistant_id)
|
||||
if assistant_id in chosen_assistants:
|
||||
chosen_assistants.remove(assistant_id)
|
||||
|
||||
preferences.visible_assistants = visible_assistants
|
||||
preferences.hidden_assistants = hidden_assistants
|
||||
preferences.chosen_assistants = chosen_assistants
|
||||
return preferences
|
||||
|
||||
|
||||
@@ -670,15 +664,23 @@ def update_user_assistant_visibility(
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
preferences = no_auth_user.preferences
|
||||
updated_preferences = update_assistant_list(preferences, assistant_id, show)
|
||||
updated_preferences = update_assistant_visibility(
|
||||
preferences, assistant_id, show
|
||||
)
|
||||
if updated_preferences.chosen_assistants is not None:
|
||||
updated_preferences.chosen_assistants.append(assistant_id)
|
||||
|
||||
set_no_auth_user_preferences(store, updated_preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
user_preferences = UserInfo.from_model(user).preferences
|
||||
updated_preferences = update_assistant_list(user_preferences, assistant_id, show)
|
||||
|
||||
updated_preferences = update_assistant_visibility(
|
||||
user_preferences, assistant_id, show
|
||||
)
|
||||
if updated_preferences.chosen_assistants is not None:
|
||||
updated_preferences.chosen_assistants.append(assistant_id)
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
|
||||
273
backend/danswer/server/openai_assistants_api/asssistants_api.py
Normal file
@@ -0,0 +1,273 @@
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.db.tools import get_tool_by_name
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/assistants")
|
||||
|
||||
|
||||
# Base models
|
||||
class AssistantObject(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant"
|
||||
created_at: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
model: str
|
||||
instructions: Optional[str] = None
|
||||
tools: list[dict[str, Any]]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
|
||||
model: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModifyAssistantRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant.deleted"
|
||||
deleted: bool
|
||||
|
||||
|
||||
class ListAssistantsResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: list[AssistantObject]
|
||||
first_id: Optional[int] = None
|
||||
last_id: Optional[int] = None
|
||||
has_more: bool
|
||||
|
||||
|
||||
def persona_to_assistant(persona: Persona) -> AssistantObject:
|
||||
return AssistantObject(
|
||||
id=persona.id,
|
||||
created_at=0,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
model=persona.llm_model_version_override or "gpt-3.5-turbo",
|
||||
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
|
||||
tools=[
|
||||
{
|
||||
"type": tool.display_name,
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"schema": tool.openapi_schema,
|
||||
},
|
||||
}
|
||||
for tool in persona.tools
|
||||
],
|
||||
file_ids=[], # Assuming no file support for now
|
||||
metadata={}, # Assuming no metadata for now
|
||||
)
|
||||
|
||||
|
||||
# API endpoints
|
||||
@router.post("")
|
||||
def create_assistant(
|
||||
request: CreateAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
prompt = None
|
||||
if request.instructions:
|
||||
prompt = upsert_prompt(
|
||||
user=user,
|
||||
name=f"Prompt for {request.name or 'New Assistant'}",
|
||||
description="Auto-generated prompt",
|
||||
system_prompt=request.instructions,
|
||||
task_prompt="",
|
||||
include_citations=True,
|
||||
datetime_aware=True,
|
||||
personas=[],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
tool_ids = []
|
||||
for tool in request.tools or []:
|
||||
tool_type = tool.get("type")
|
||||
if not tool_type:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_db = get_tool_by_name(tool_type, db_session)
|
||||
tool_ids.append(tool_db.id)
|
||||
except ValueError:
|
||||
# Skip tools that don't exist in the database
|
||||
logger.error(f"Tool {tool_type} not found in database")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tool {tool_type} not found in database"
|
||||
)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=request.name or f"Assistant-{uuid4()}",
|
||||
description=request.description or "",
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=request.model,
|
||||
starter_messages=None,
|
||||
is_public=False,
|
||||
db_session=db_session,
|
||||
prompt_ids=[prompt.id] if prompt else [0],
|
||||
document_set_ids=[],
|
||||
tool_ids=tool_ids,
|
||||
icon_color=None,
|
||||
icon_shape=None,
|
||||
is_visible=True,
|
||||
)
|
||||
|
||||
if prompt:
|
||||
prompt.personas = [persona]
|
||||
db_session.commit()
|
||||
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
""
|
||||
|
||||
|
||||
@router.get("/{assistant_id}")
|
||||
def retrieve_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
try:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
except ValueError:
|
||||
persona = None
|
||||
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.post("/{assistant_id}")
|
||||
def modify_assistant(
|
||||
assistant_id: int,
|
||||
request: ModifyAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=True,
|
||||
)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(persona, key, value)
|
||||
|
||||
if "instructions" in update_data and persona.prompts:
|
||||
persona.prompts[0].system_prompt = update_data["instructions"]
|
||||
|
||||
db_session.commit()
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.delete("/{assistant_id}")
|
||||
def delete_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DeleteAssistantResponse:
|
||||
try:
|
||||
mark_persona_as_deleted(
|
||||
persona_id=int(assistant_id),
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
return DeleteAssistantResponse(id=assistant_id, deleted=True)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_assistants(
|
||||
limit: int = Query(20, le=100),
|
||||
order: str = Query("desc", regex="^(asc|desc)$"),
|
||||
after: Optional[int] = None,
|
||||
before: Optional[int] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListAssistantsResponse:
|
||||
personas = list(
|
||||
get_personas(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
get_editable=False,
|
||||
joinedload_all=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
personas = [p for p in personas if p.id > int(after)]
|
||||
if before:
|
||||
personas = [p for p in personas if p.id < int(before)]
|
||||
|
||||
# Apply ordering
|
||||
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
personas = personas[:limit]
|
||||
|
||||
assistants = [persona_to_assistant(p) for p in personas]
|
||||
|
||||
return ListAssistantsResponse(
|
||||
data=assistants,
|
||||
first_id=assistants[0].id if assistants else None,
|
||||
last_id=assistants[-1].id if assistants else None,
|
||||
has_more=len(personas) == limit,
|
||||
)
|
||||
@@ -0,0 +1,19 @@
from fastapi import APIRouter

from danswer.server.openai_assistants_api.asssistants_api import (
    router as assistants_router,
)
from danswer.server.openai_assistants_api.messages_api import router as messages_router
from danswer.server.openai_assistants_api.runs_api import router as runs_router
from danswer.server.openai_assistants_api.threads_api import router as threads_router


def get_full_openai_assistants_api_router() -> APIRouter:
    router = APIRouter(prefix="/openai-assistants")

    router.include_router(assistants_router)
    router.include_router(runs_router)
    router.include_router(threads_router)
    router.include_router(messages_router)

    return router
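Reviewer note: a small sketch of how the aggregated router above is meant to be mounted, mirroring the get_application() change earlier in this diff; the bare FastAPI app is a stand-in, since the real server also applies the global prefix helper.

from fastapi import FastAPI

app = FastAPI()
app.include_router(get_full_openai_assistants_api_router())
# endpoints are then served under /openai-assistants/..., e.g. POST /openai-assistants/assistants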
235	backend/danswer/server/openai_assistants_api/messages_api.py	Normal file
@@ -0,0 +1,235 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.utils import check_number_of_tokens

router = APIRouter(prefix="")


Role = Literal["user", "assistant"]


class MessageContent(BaseModel):
    type: Literal["text"]
    text: str


class Message(BaseModel):
    id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
    object: Literal["thread.message"] = "thread.message"
    created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    thread_id: str
    role: Role
    content: list[MessageContent]
    file_ids: list[str] = []
    assistant_id: Optional[str] = None
    run_id: Optional[str] = None
    metadata: Optional[dict[str, Any]] = None


class CreateMessageRequest(BaseModel):
    role: Role
    content: str
    file_ids: list[str] = []
    metadata: Optional[dict] = None


class ListMessagesResponse(BaseModel):
    object: Literal["list"] = "list"
    data: list[Message]
    first_id: str
    last_id: str
    has_more: bool


@router.post("/threads/{thread_id}/messages")
def create_message(
    thread_id: str,
    message: CreateMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    user_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Chat session not found")

    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )

    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message=message.content,
        prompt_id=chat_session.persona.prompts[0].id,
        token_count=check_number_of_tokens(message.content),
        message_type=(
            MessageType.USER if message.role == "user" else MessageType.ASSISTANT
        ),
        db_session=db_session,
    )

    return Message(
        id=str(new_message.id),
        thread_id=thread_id,
        role="user",
        content=[MessageContent(type="text", text=message.content)],
        file_ids=message.file_ids,
        metadata=message.metadata,
    )


@router.get("/threads/{thread_id}/messages")
def list_messages(
    thread_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
    user_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Chat session not found")

    messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user_id,
        db_session=db_session,
    )

    # Apply filtering based on after and before
    if after:
        messages = [m for m in messages if str(m.id) >= after]
    if before:
        messages = [m for m in messages if str(m.id) <= before]

    # Apply ordering
    messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))

    # Apply limit
    messages = messages[:limit]

    data = [
        Message(
            id=str(m.id),
            thread_id=thread_id,
            role="user" if m.message_type == "user" else "assistant",
            content=[MessageContent(type="text", text=m.message)],
            created_at=int(m.time_sent.timestamp()),
        )
        for m in messages
    ]

    return ListMessagesResponse(
        data=data,
        first_id=str(data[0].id) if data else "",
        last_id=str(data[-1].id) if data else "",
        has_more=len(messages) == limit,
    )


@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
    thread_id: str,
    message_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    user_id = user.id if user else None

    try:
        chat_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    return Message(
        id=str(chat_message.id),
        thread_id=thread_id,
        role="user" if chat_message.message_type == "user" else "assistant",
        content=[MessageContent(type="text", text=chat_message.message)],
        created_at=int(chat_message.time_sent.timestamp()),
    )


class ModifyMessageRequest(BaseModel):
    metadata: dict


@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
    thread_id: str,
    message_id: int,
    request: ModifyMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    user_id = user.id if user else None

    try:
        chat_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    # Update metadata
    # TODO: Uncomment this once we have metadata in the chat message
    # chat_message.metadata = request.metadata
    # db_session.commit()

    return Message(
        id=str(chat_message.id),
        thread_id=thread_id,
        role="user" if chat_message.message_type == "user" else "assistant",
        content=[MessageContent(type="text", text=chat_message.message)],
        created_at=int(chat_message.time_sent.timestamp()),
        metadata=request.metadata,
    )
344	backend/danswer/server/openai_assistants_api/runs_api.py	Normal file
@@ -0,0 +1,344 @@
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
assistant_id: int
|
||||
model: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
additional_instructions: Optional[str] = None
|
||||
tools: Optional[list[dict]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
RunStatus = Literal[
|
||||
"queued",
|
||||
"in_progress",
|
||||
"requires_action",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
"failed",
|
||||
"completed",
|
||||
"expired",
|
||||
]
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
id: str
|
||||
object: Literal["thread.run"]
|
||||
created_at: int
|
||||
assistant_id: int
|
||||
thread_id: UUID
|
||||
status: RunStatus
|
||||
started_at: Optional[int] = None
|
||||
expires_at: Optional[int] = None
|
||||
cancelled_at: Optional[int] = None
|
||||
failed_at: Optional[int] = None
|
||||
completed_at: Optional[int] = None
|
||||
last_error: Optional[dict] = None
|
||||
model: str
|
||||
instructions: str
|
||||
tools: list[dict]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
def process_run_in_background(
|
||||
message_id: int,
|
||||
parent_message_id: int,
|
||||
chat_session_id: UUID,
|
||||
assistant_id: int,
|
||||
instructions: str,
|
||||
tools: list[dict],
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Get the latest message in the chat session
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
search_tool_retrieval_details = RetrievalDetails()
|
||||
for tool in tools:
|
||||
if tool["type"] == SearchTool.__name__ and (
|
||||
retrieval_details := tool.get("retrieval_details")
|
||||
):
|
||||
search_tool_retrieval_details = RetrievalDetails.model_validate(
|
||||
retrieval_details
|
||||
)
|
||||
break
|
||||
|
||||
new_msg_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=int(parent_message_id) if parent_message_id else None,
|
||||
message=instructions,
|
||||
file_descriptors=[],
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=search_tool_retrieval_details, # Adjust as needed
|
||||
query_override=None,
|
||||
regenerate=None,
|
||||
llm_override=None,
|
||||
prompt_override=None,
|
||||
alternate_assistant_id=assistant_id,
|
||||
use_existing_user_message=True,
|
||||
existing_assistant_message_id=message_id,
|
||||
)
|
||||
|
||||
run_message = get_chat_message(message_id, user.id if user else None, db_session)
|
||||
try:
|
||||
for packet in stream_chat_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
):
|
||||
if isinstance(packet, ChatMessageDetail):
|
||||
# Update the run status and message content
|
||||
run_message = get_chat_message(
|
||||
message_id, user.id if user else None, db_session
|
||||
)
|
||||
if run_message:
|
||||
# this handles cancelling
|
||||
if run_message.error:
|
||||
return
|
||||
|
||||
run_message.message = packet.message
|
||||
run_message.message_type = MessageType.ASSISTANT
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception("Error processing run in background")
|
||||
run_message.error = str(e)
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
db_session.refresh(run_message)
|
||||
if run_message.token_count == 0:
|
||||
run_message.error = "No tokens generated"
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs")
|
||||
def create_run(
|
||||
thread_id: UUID,
|
||||
run_request: RunRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
# Create a new "run" (chat message) in the session
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message="",
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.flush()
|
||||
latest_message.latest_child_message = new_message.id
|
||||
db_session.commit()
|
||||
|
||||
# Schedule the background task
|
||||
background_tasks.add_task(
|
||||
process_run_in_background,
|
||||
new_message.id,
|
||||
latest_message.id,
|
||||
chat_session.id,
|
||||
run_request.assistant_id,
|
||||
run_request.instructions or "",
|
||||
run_request.tools or [],
|
||||
user,
|
||||
db_session,
|
||||
)
|
||||
|
||||
return RunResponse(
|
||||
id=str(new_message.id),
|
||||
object="thread.run",
|
||||
created_at=int(new_message.time_sent.timestamp()),
|
||||
assistant_id=run_request.assistant_id,
|
||||
thread_id=chat_session.id,
|
||||
status="queued",
|
||||
model=run_request.model or "default_model",
|
||||
instructions=run_request.instructions or "",
|
||||
tools=run_request.tools or [],
|
||||
file_ids=[],
|
||||
metadata=run_request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}")
|
||||
def retrieve_run(
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
# Retrieve the chat message (which represents a "run" in DAnswer)
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=int(run_id), # Convert string run_id to int
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
if not chat_message:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
chat_session = chat_message.chat_session
|
||||
|
||||
# Map DAnswer status to OpenAI status
|
||||
run_status: RunStatus = "queued"
|
||||
if chat_message.message:
|
||||
run_status = "in_progress"
|
||||
if chat_message.token_count != 0:
|
||||
run_status = "completed"
|
||||
if chat_message.error:
|
||||
run_status = "cancelled"
|
||||
|
||||
return RunResponse(
|
||||
id=run_id,
|
||||
object="thread.run",
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
assistant_id=chat_session.persona_id or 0,
|
||||
thread_id=chat_session.id,
|
||||
status=run_status,
|
||||
started_at=int(chat_message.time_sent.timestamp()),
|
||||
completed_at=(
|
||||
int(chat_message.time_sent.timestamp()) if chat_message.message else None
|
||||
),
|
||||
model=chat_session.current_alternate_model or "default_model",
|
||||
instructions="", # DAnswer doesn't store per-message instructions
|
||||
tools=[], # DAnswer doesn't have a direct equivalent for tools
|
||||
file_ids=(
|
||||
[file["id"] for file in chat_message.files] if chat_message.files else []
|
||||
),
|
||||
metadata=None, # DAnswer doesn't store metadata for individual messages
|
||||
)
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
|
||||
def cancel_run(
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
# In DAnswer, we don't have a direct equivalent to cancelling a run
|
||||
# We'll simulate it by marking the message as "cancelled"
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
|
||||
)
|
||||
if not chat_message:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
chat_message.error = "Cancelled"
|
||||
db_session.commit()
|
||||
|
||||
return retrieve_run(thread_id, run_id, user, db_session)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs")
|
||||
def list_runs(
|
||||
thread_id: UUID,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[RunResponse]:
|
||||
# In DAnswer, we'll treat each message in a chat session as a "run"
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply pagination
|
||||
if after:
|
||||
chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
|
||||
if before:
|
||||
chat_messages = [msg for msg in chat_messages if str(msg.id) < before]
|
||||
|
||||
# Apply ordering
|
||||
chat_messages = sorted(
|
||||
chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
|
||||
)
|
||||
|
||||
# Apply limit
|
||||
chat_messages = chat_messages[:limit]
|
||||
|
||||
return [
|
||||
retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
|
||||
def list_run_steps(
|
||||
run_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[dict]: # You may want to create a specific model for run steps
|
||||
# DAnswer doesn't have an equivalent to run steps
|
||||
# We'll return an empty list to maintain API compatibility
|
||||
return []
|
||||
|
||||
|
||||
# Additional helper functions can be added here if needed
|
||||
156	backend/danswer/server/openai_assistants_api/threads_api.py	Normal file
@@ -0,0 +1,156 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.db.chat import create_chat_session
from danswer.db.chat import delete_chat_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import update_chat_session
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse

router = APIRouter(prefix="/threads")


# Models
class Thread(BaseModel):
    id: UUID
    object: str = "thread"
    created_at: int
    metadata: Optional[dict[str, str]] = None


class CreateThreadRequest(BaseModel):
    messages: Optional[list[dict]] = None
    metadata: Optional[dict[str, str]] = None


class ModifyThreadRequest(BaseModel):
    metadata: Optional[dict[str, str]] = None


# API Endpoints
@router.post("")
def create_thread(
    request: CreateThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    new_chat_session = create_chat_session(
        db_session=db_session,
        description="",  # Leave the naming till later to prevent delay
        user_id=user_id,
        persona_id=0,
    )

    return Thread(
        id=new_chat_session.id,
        created_at=int(new_chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )


@router.get("/{thread_id}")
def retrieve_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=None,  # Assuming we don't store metadata in our current implementation
    )


@router.post("/{thread_id}")
def modify_thread(
    thread_id: UUID,
    request: ModifyThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    try:
        chat_session = update_chat_session(
            db_session=db_session,
            user_id=user_id,
            chat_session_id=thread_id,
            description=None,  # Not updating description
            sharing_status=None,  # Not updating sharing status
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )


@router.delete("/{thread_id}")
def delete_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict:
    user_id = user.id if user else None
    try:
        delete_chat_session(
            user_id=user_id,
            chat_session_id=thread_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}


@router.get("")
def list_threads(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    user_id = user.id if user else None
    chat_sessions = get_chat_sessions_by_user(
        user_id=user_id,
        deleted=False,
        db_session=db_session,
    )

    return ChatSessionsResponse(
        sessions=[
            ChatSessionDetails(
                id=chat.id,
                name=chat.description,
                persona_id=chat.persona_id,
                time_created=chat.time_created.isoformat(),
                shared_status=chat.shared_status,
                folder_id=chat.folder_id,
                current_alternate_model=chat.current_alternate_model,
            )
            for chat in chat_sessions
        ]
    )
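For context (added, not part of the diff): a rough end-to-end usage sketch of the thread, message, and run endpoints introduced above. The base URL, the assistant_id value, and the absence of authentication are assumptions for illustration.

# Rough usage sketch; base URL, assistant_id, and missing auth are assumptions.
import requests

BASE = "http://localhost:8080/openai-assistants"

# Create a thread (backed by a Danswer chat session)
thread = requests.post(f"{BASE}/threads", json={}).json()

# Add a user message to the thread
requests.post(
    f"{BASE}/threads/{thread['id']}/messages",
    json={"role": "user", "content": "Hello"},
)

# Kick off a run against assistant (persona) 0, then poll its status
run = requests.post(
    f"{BASE}/threads/{thread['id']}/runs", json={"assistant_id": 0}
).json()
status = requests.get(f"{BASE}/threads/{thread['id']}/runs/{run['id']}").json()["status"]
# status moves from "queued" to "in_progress" to "completed" (or "cancelled" on error)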
@@ -18,6 +18,7 @@ from PIL import Image
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
@@ -309,7 +310,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
    chat_message_req: CreateChatMessageRequest,
    request: Request,
    user: User | None = Depends(current_user),
    user: User | None = Depends(current_limited_user),
    _: None = Depends(check_token_rate_limits),
    is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
@@ -347,7 +348,6 @@ def handle_new_chat_message(
    for packet in stream_chat_message(
        new_msg_req=chat_message_req,
        user=user,
        use_existing_user_message=chat_message_req.use_existing_user_message,
        litellm_additional_headers=extract_headers(
            request.headers, LITELLM_PASS_THROUGH_HEADERS
        ),
@@ -392,7 +392,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
    feedback: ChatFeedbackRequest,
    user: User | None = Depends(current_user),
    user: User | None = Depends(current_limited_user),
    db_session: Session = Depends(get_session),
) -> None:
    user_id = user.id if user else None

@@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext):
    # used for seeded chats to kick off the generation of an AI answer
    use_existing_user_message: bool = False

    # used for "OpenAI Assistants API"
    existing_assistant_message_id: int | None = None

    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

@@ -9,6 +9,7 @@ from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
@@ -262,7 +263,7 @@ def stream_query_validation(
@basic_router.post("/stream-answer-with-quote")
def get_answer_with_quote(
    query_request: DirectQARequest,
    user: User = Depends(current_user),
    user: User = Depends(current_limited_user),
    _: None = Depends(check_token_rate_limits),
) -> StreamingResponse:
    query = query_request.messages[0].message

@@ -59,7 +59,9 @@ from shared_configs.model_server_models import SupportedEmbeddingModel
logger = setup_logger()


def setup_danswer(db_session: Session, tenant_id: str | None) -> None:
def setup_danswer(
    db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
) -> None:
    """
    Setup Danswer for a particular tenant. In the Single Tenant case, it will set it up for the default schema
    on server startup. In the MT case, it will be called when the tenant is created.
@@ -148,7 +150,7 @@ def setup_danswer(db_session: Session, tenant_id: str | None) -> None:
    # update multipass indexing setting based on GPU availability
    update_default_multipass_indexing(db_session)

    seed_initial_documents(db_session, tenant_id)
    seed_initial_documents(db_session, tenant_id, cohere_enabled)


def translate_saved_search_settings(db_session: Session) -> None:

255	backend/danswer/tools/tool_constructor.py	Normal file
@@ -0,0 +1,255 @@
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
"""Helper function to get image generation LLM config based on available providers"""
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
return LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
|
||||
if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
|
||||
return LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
|
||||
# Fallback to checking for OpenAI provider in database
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError("Image generation tool requires an OpenAI API key")
|
||||
|
||||
return LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
|
||||
|
||||
class SearchToolConfig(BaseModel):
|
||||
answer_style_config: AnswerStyleConfig = Field(
|
||||
default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
|
||||
)
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
chunks_above: int = 0
|
||||
chunks_below: int = 0
|
||||
full_doc: bool = False
|
||||
latest_query_files: list[InMemoryChatFile] | None = None
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
|
||||
answer_style_config: AnswerStyleConfig = Field(
|
||||
default_factory=lambda: AnswerStyleConfig(
|
||||
citation_config=CitationConfig(all_docs_useful=True)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationToolConfig(BaseModel):
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
class CustomToolConfig(BaseModel):
|
||||
chat_session_id: UUID | None = None
|
||||
message_id: int | None = None
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
prompt_config: PromptConfig,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
internet_search_tool_config: InternetSearchToolConfig | None = None,
|
||||
image_generation_tool_config: ImageGenerationToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
|
||||
# Handle Search Tool
|
||||
if tool_cls.__name__ == SearchTool.__name__:
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
if not image_generation_tool_config:
|
||||
image_generation_tool_config = ImageGenerationToolConfig()
|
||||
|
||||
img_generation_llm_config = _get_image_generation_config(
|
||||
llm, db_session
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=image_generation_tool_config.additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Internet Search Tool
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
if not internet_search_tool_config:
|
||||
internet_search_tool_config = InternetSearchToolConfig()
|
||||
|
||||
if not BING_API_KEY:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=BING_API_KEY,
|
||||
answer_style_config=internet_search_tool_config.answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle custom tools
|
||||
elif db_tool_model.openapi_schema:
|
||||
if not custom_tool_config:
|
||||
custom_tool_config = CustomToolConfig()
|
||||
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=custom_tool_config.chat_session_id,
|
||||
message_id=custom_tool_config.message_id,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
if search_tool_config:
|
||||
search_tool_config.document_pruning_config.tool_num_tokens = (
|
||||
compute_all_tool_tokens(
|
||||
tools,
|
||||
get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
),
|
||||
)
|
||||
)
|
||||
search_tool_config.document_pruning_config.using_tool_message = (
|
||||
explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
)
|
||||
|
||||
return tool_dict
|
||||
@@ -21,8 +21,12 @@ pruning_ctx: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar(
|
||||
"pruning_ctx", default=dict()
|
||||
)
|
||||
|
||||
doc_permission_sync_ctx: contextvars.ContextVar[
|
||||
dict[str, Any]
|
||||
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
|
||||
|
||||
class IndexAttemptSingleton:
|
||||
|
||||
class TaskAttemptSingleton:
|
||||
"""Used to tell if this process is an indexing job, and if so what is the
|
||||
unique identifier for this indexing attempt. For things like the API server,
|
||||
main background job (scheduler), etc. this will not be used."""
|
||||
@@ -66,9 +70,10 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
|
||||
) -> tuple[str, MutableMapping[str, Any]]:
|
||||
# If this is an indexing job, add the attempt ID to the log message
|
||||
# This helps filter the logs for this specific indexing
|
||||
index_attempt_id = IndexAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = IndexAttemptSingleton.get_connector_credential_pair_id()
|
||||
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
|
||||
|
||||
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
|
||||
pruning_ctx_dict = pruning_ctx.get()
|
||||
if len(pruning_ctx_dict) > 0:
|
||||
if "request_id" in pruning_ctx_dict:
|
||||
@@ -76,6 +81,9 @@ class DanswerLoggingAdapter(logging.LoggerAdapter):
|
||||
|
||||
if "cc_pair_id" in pruning_ctx_dict:
|
||||
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
|
||||
elif len(doc_permission_sync_ctx_dict) > 0:
|
||||
if "request_id" in doc_permission_sync_ctx_dict:
|
||||
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
|
||||
else:
|
||||
if index_attempt_id is not None:
|
||||
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -139,8 +140,19 @@ def fetch_ee_implementation_or_noop(
|
||||
Exception: If EE is enabled but the fetch fails.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
return lambda *args, **kwargs: noop_return_value
|
||||
if inspect.iscoroutinefunction(noop_return_value):
|
||||
|
||||
async def async_noop(*args: Any, **kwargs: Any) -> Any:
|
||||
return await noop_return_value(*args, **kwargs)
|
||||
|
||||
return async_noop
|
||||
|
||||
else:
|
||||
|
||||
def sync_noop(*args: Any, **kwargs: Any) -> Any:
|
||||
return noop_return_value
|
||||
|
||||
return sync_noop
|
||||
try:
|
||||
return fetch_versioned_implementation(module, attribute)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,32 +1,12 @@
|
||||
from danswer.background.celery.apps.primary import celery_app
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.db.chat import delete_chat_sessions_older_than
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_doc_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_group_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_doc_permission_sync,
|
||||
)
|
||||
from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_group_permission_sync,
|
||||
)
|
||||
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -34,25 +14,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
# mark as EE for all tasks in this file
|
||||
global_version.set_ee()
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_group_permissions_task(
|
||||
cc_pair_id: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@@ -67,38 +28,6 @@ def perform_ttl_management_task(
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_sync_external_doc_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_doc_permissions_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external permissions"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_doc_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_sync_external_group_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_group_permissions_task(*, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external group permissions"""
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_group_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_group_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
|
||||
@@ -6,16 +6,6 @@ from danswer.background.celery.tasks.beat_schedule import (
|
||||
)
|
||||
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "sync-external-doc-permissions",
|
||||
"task": "check_sync_external_doc_permissions_task",
|
||||
"schedule": timedelta(seconds=30), # TODO: optimize this
|
||||
},
|
||||
{
|
||||
"name": "sync-external-group-permissions",
|
||||
"task": "check_sync_external_group_permissions_task",
|
||||
"schedule": timedelta(seconds=60), # TODO: optimize this
|
||||
},
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": "autogenerate_usage_report_task",
|
||||
|
||||
@@ -1,46 +1,13 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
from ee.danswer.external_permissions.sync_params import PERMISSION_SYNC_PERIODS
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_time_to_run_sync(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
source_sync_period = PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
|
||||
|
||||
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
|
||||
if not source_sync_period:
|
||||
return True
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
if cc_pair.last_time_perm_sync is None:
|
||||
return True
|
||||
|
||||
last_sync = cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
|
||||
# If the last sync is greater than the full fetch period, we run the sync
|
||||
if (current_time - last_sync).total_seconds() > source_sync_period:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def should_perform_chat_ttl_check(
|
||||
retention_limit_days: int | None, db_session: Session
|
||||
) -> bool:
|
||||
@@ -57,47 +24,3 @@ def should_perform_chat_ttl_check(
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def should_perform_external_doc_permissions_check(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
task_name = name_sync_external_doc_permissions_task(cc_pair_id=cc_pair.id)
|
||||
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_perform_external_group_permissions_check(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
task_name = name_sync_external_group_permissions_task(cc_pair_id=cc_pair.id)
|
||||
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
if check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.debug(f"{task_name} is already being performed. Skipping.")
|
||||
return False
|
||||
|
||||
if not _is_time_to_run_sync(cc_pair):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,8 +1,2 @@
|
||||
def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
def name_sync_external_group_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_group_permissions_task__{cc_pair_id}"
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -45,3 +48,53 @@ def upsert_document_external_perms__no_commit(
|
||||
document.external_user_emails = list(external_access.external_user_emails)
|
||||
document.external_user_group_ids = prefixed_external_groups
|
||||
document.is_public = external_access.is_public
|
||||
|
||||
|
||||
def upsert_document_external_perms(
|
||||
db_session: Session,
|
||||
doc_id: str,
|
||||
external_access: ExternalAccess,
|
||||
source_type: DocumentSource,
|
||||
) -> None:
|
||||
"""
|
||||
This sets the permissions for a document in postgres.
|
||||
NOTE: this will replace any existing external access, it will not do a union
|
||||
"""
|
||||
document = db_session.scalars(
|
||||
select(DbDocument).where(DbDocument.id == doc_id)
|
||||
).first()
|
||||
|
||||
prefixed_external_groups: set[str] = {
|
||||
prefix_group_w_source(
|
||||
ext_group_name=group_id,
|
||||
source=source_type,
|
||||
)
|
||||
for group_id in external_access.external_user_group_ids
|
||||
}
|
||||
|
||||
if not document:
|
||||
# If the document does not exist, still store the external access
|
||||
# So that if the document is added later, the external access is already stored
|
||||
# The upsert function in the indexing pipeline does not overwrite the permissions fields
|
||||
document = DbDocument(
|
||||
id=doc_id,
|
||||
semantic_id="",
|
||||
external_user_emails=external_access.external_user_emails,
|
||||
external_user_group_ids=prefixed_external_groups,
|
||||
is_public=external_access.is_public,
|
||||
)
|
||||
db_session.add(document)
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
# If the document exists, we need to check if the external access has changed
|
||||
if (
|
||||
external_access.external_user_emails != set(document.external_user_emails or [])
|
||||
or prefixed_external_groups != set(document.external_user_group_ids or [])
|
||||
or external_access.is_public != document.is_public
|
||||
):
|
||||
document.external_user_emails = list(external_access.external_user_emails)
|
||||
document.external_user_group_ids = list(prefixed_external_groups)
|
||||
document.is_public = external_access.is_public
|
||||
document.last_modified = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
@@ -9,11 +9,12 @@ from sqlalchemy.orm import Session
|
||||
from danswer.access.utils import prefix_group_w_source
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import User__ExternalUserGroupId
|
||||
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
|
||||
|
||||
|
||||
class ExternalUserGroup(BaseModel):
|
||||
id: str
|
||||
user_ids: list[UUID]
|
||||
user_emails: list[str]
|
||||
|
||||
|
||||
def delete_user__ext_group_for_user__no_commit(
|
||||
@@ -38,7 +39,7 @@ def delete_user__ext_group_for_cc_pair__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def replace_user__ext_group_for_cc_pair__no_commit(
|
||||
def replace_user__ext_group_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
group_defs: list[ExternalUserGroup],
|
||||
@@ -46,24 +47,44 @@ def replace_user__ext_group_for_cc_pair__no_commit(
|
||||
) -> None:
|
||||
"""
|
||||
This function clears all existing external user group relations for a given cc_pair_id
|
||||
and replaces them with the new group definitions.
|
||||
and replaces them with the new group definitions and commits the changes.
|
||||
"""
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
new_external_permissions = [
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=prefix_group_w_source(external_group.id, source),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
for external_group in group_defs
|
||||
for user_id in external_group.user_ids
|
||||
]
|
||||
# collect all emails from all groups to batch add all users at once for efficiency
|
||||
all_group_member_emails = set()
|
||||
for external_group in group_defs:
|
||||
for user_email in external_group.user_emails:
|
||||
all_group_member_emails.add(user_email)
|
||||
|
||||
# batch add users if they don't exist and get their ids
|
||||
all_group_members = batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session=db_session, emails=list(all_group_member_emails)
|
||||
)
|
||||
|
||||
# map emails to ids
|
||||
email_id_map = {user.email: user.id for user in all_group_members}
|
||||
|
||||
# use these ids to create new external user group relations relating group_id to user_ids
|
||||
new_external_permissions = []
|
||||
for external_group in group_defs:
|
||||
for user_email in external_group.user_emails:
|
||||
user_id = email_id_map[user_email]
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
external_user_group_id=prefix_group_w_source(
|
||||
external_group.id, source
|
||||
),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
)
|
||||
|
||||
db_session.add_all(new_external_permissions)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_external_groups_for_user(
|
||||
|
||||
@@ -124,16 +124,21 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
def validate_user_creation_permissions(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None,
|
||||
object_is_public: bool | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
object_is_public: bool | None = None,
|
||||
object_is_perm_sync: bool | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
All users can create/edit permission synced objects if they don't specify a group
|
||||
All admin actions are allowed.
|
||||
Prevents non-admins from creating/editing:
|
||||
- public objects
|
||||
- objects with no groups
|
||||
- objects that belong to a group they don't curate
|
||||
"""
|
||||
if object_is_perm_sync and not target_group_ids:
|
||||
return
|
||||
|
||||
if not user or user.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
|
||||
@@ -4,17 +4,14 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
"""
from typing import Any

from sqlalchemy.orm import Session

from danswer.access.models import DocExternalAccess
from danswer.access.models import ExternalAccess
from danswer.connectors.confluence.connector import ConfluenceConnector
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit

logger = setup_logger()

@@ -163,7 +160,13 @@ def _extract_read_access_restrictions(
                    f"Email for user {user['username']} not found in Confluence"
                )
        else:
            logger.warning(f"User {user} does not have an email or username")
            if user.get("email") is not None:
                logger.warning(f"Cant find email for user {user.get('displayName')}")
                logger.warning(
                    "This user needs to make their email accessible in Confluence Settings"
                )

            logger.warning(f"no user email or username for {user}")

    # Extract the groups with read access
    read_access_group = read_access_restrictions.get("group", {})
@@ -190,12 +193,12 @@ def _fetch_all_page_restrictions_for_space(
    confluence_client: OnyxConfluence,
    slim_docs: list[SlimDocument],
    space_permissions_by_space_key: dict[str, ExternalAccess],
) -> dict[str, ExternalAccess]:
) -> list[DocExternalAccess]:
    """
    For all pages, if a page has restrictions, then use those restrictions.
    Otherwise, use the space's restrictions.
    """
    document_restrictions: dict[str, ExternalAccess] = {}
    document_restrictions: list[DocExternalAccess] = []

    for slim_doc in slim_docs:
        if slim_doc.perm_sync_data is None:
@@ -207,21 +210,34 @@ def _fetch_all_page_restrictions_for_space(
            restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
        )
        if restrictions:
            document_restrictions[slim_doc.id] = restrictions
        else:
            space_key = slim_doc.perm_sync_data.get("space_key")
            if space_permissions := space_permissions_by_space_key.get(space_key):
                document_restrictions[slim_doc.id] = space_permissions
            else:
                logger.warning(f"No permissions found for document {slim_doc.id}")
            document_restrictions.append(
                DocExternalAccess(
                    doc_id=slim_doc.id,
                    external_access=restrictions,
                )
            )
            # If there are restrictions, then we don't need to use the space's restrictions
            continue

        space_key = slim_doc.perm_sync_data.get("space_key")
        if space_permissions := space_permissions_by_space_key.get(space_key):
            # If there are no restrictions, then use the space's restrictions
            document_restrictions.append(
                DocExternalAccess(
                    doc_id=slim_doc.id,
                    external_access=space_permissions,
                )
            )
            continue

        logger.warning(f"No permissions found for document {slim_doc.id}")

    return document_restrictions


def confluence_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
    if the document doesn't already exists in postgres, we create
@@ -247,20 +263,8 @@ def confluence_doc_sync(
    for doc_batch in confluence_connector.retrieve_all_slim_documents():
        slim_docs.extend(doc_batch)

    permissions_by_doc_id = _fetch_all_page_restrictions_for_space(
    return _fetch_all_page_restrictions_for_space(
        confluence_client=confluence_client,
        slim_docs=slim_docs,
        space_permissions_by_space_key=space_permissions_by_space_key,
    )

    all_emails = set()
    for doc_id, page_specific_access in permissions_by_doc_id.items():
        upsert_document_external_perms__no_commit(
            db_session=db_session,
            doc_id=doc_id,
            external_access=page_specific_access,
            source_type=cc_pair.connector.source,
        )
        all_emails.update(page_specific_access.external_user_emails)

    batch_add_non_web_user_if_not_exists__no_commit(db_session, list(all_emails))
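The sync function above now only returns DocExternalAccess records; below is a sketch of how a caller might persist them, reusing the upsert helper whose call was removed here. The caller itself is an assumption and is not part of this diff.

from sqlalchemy.orm import Session

from danswer.access.models import DocExternalAccess
from danswer.configs.constants import DocumentSource
from ee.danswer.db.document import upsert_document_external_perms__no_commit


def persist_doc_access(
    db_session: Session,
    doc_accesses: list[DocExternalAccess],
    source_type: DocumentSource,
) -> None:
    # doc_accesses would come from one of the reworked *_doc_sync functions.
    for doc_access in doc_accesses:
        upsert_document_external_perms__no_commit(
            db_session=db_session,
            doc_id=doc_access.doc_id,
            external_access=doc_access.external_access,
            source_type=source_type,
        )
    db_session.commit()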
@@ -1,15 +1,11 @@
from typing import Any

from sqlalchemy.orm import Session

from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import get_user_email_from_username__server
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit


logger = setup_logger()
@@ -40,9 +36,8 @@ def _get_group_members_email_paginated(


def confluence_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[ExternalUserGroup]:
    is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
    confluence_client = build_confluence_client(
        credentials_json=cc_pair.credential.credential_json,
@@ -63,20 +58,13 @@ def confluence_group_sync(
        group_member_emails = _get_group_members_email_paginated(
            confluence_client, group_name
        )
        group_members = batch_add_non_web_user_if_not_exists__no_commit(
            db_session=db_session, emails=list(group_member_emails)
        )
        if group_members:
            danswer_groups.append(
                ExternalUserGroup(
                    id=group_name,
                    user_ids=[user.id for user in group_members],
                )
        if not group_member_emails:
            continue
        danswer_groups.append(
            ExternalUserGroup(
                id=group_name,
                user_emails=list(group_member_emails),
            )
        )

    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=danswer_groups,
        source=cc_pair.connector.source,
    )
    return danswer_groups
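A minimal sketch of the shape change in the group definitions returned above; the group ids and emails are invented.

from ee.danswer.db.external_perm import ExternalUserGroup

# Old definitions carried resolved user ids; the new ones carry raw member
# emails, and user rows are created later in one batch by the database layer.
groups = [
    ExternalUserGroup(id="confluence-space-admins", user_emails=["a@example.com"]),
    ExternalUserGroup(id="confluence-readers", user_emails=["b@example.com", "c@example.com"]),
]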
@@ -1,15 +1,12 @@
from datetime import datetime
from datetime import timezone

from sqlalchemy.orm import Session

from danswer.access.models import DocExternalAccess
from danswer.access.models import ExternalAccess
from danswer.connectors.gmail.connector import GmailConnector
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit

logger = setup_logger()

@@ -31,9 +28,8 @@ def _get_slim_doc_generator(


def gmail_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
    if the document doesn't already exists in postgres, we create
@@ -45,6 +41,7 @@ def gmail_doc_sync(

    slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)

    document_external_access: list[DocExternalAccess] = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            if slim_doc.perm_sync_data is None:
@@ -56,13 +53,11 @@ def gmail_doc_sync(
                external_user_group_ids=set(),
                is_public=False,
            )
            batch_add_non_web_user_if_not_exists__no_commit(
                db_session=db_session,
                emails=list(ext_access.external_user_emails),
            )
            upsert_document_external_perms__no_commit(
                db_session=db_session,
                doc_id=slim_doc.id,
                external_access=ext_access,
                source_type=cc_pair.connector.source,
            document_external_access.append(
                DocExternalAccess(
                    doc_id=slim_doc.id,
                    external_access=ext_access,
                )
            )

    return document_external_access
@@ -2,8 +2,7 @@ from datetime import datetime
from datetime import timezone
from typing import Any

from sqlalchemy.orm import Session

from danswer.access.models import DocExternalAccess
from danswer.access.models import ExternalAccess
from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -11,9 +10,7 @@ from danswer.connectors.google_utils.resources import get_drive_service
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.models import SlimDocument
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit

logger = setup_logger()

@@ -113,8 +110,13 @@ def _get_permissions_from_slim_doc(
        elif permission_type == "group":
            group_emails.add(permission["emailAddress"])
        elif permission_type == "domain" and company_domain:
            if permission["domain"] == company_domain:
            if permission.get("domain") == company_domain:
                public = True
            else:
                logger.warning(
                    "Permission is type domain but does not match company domain:"
                    f"\n {permission}"
                )
        elif permission_type == "anyone":
            public = True

@@ -126,9 +128,8 @@ def _get_permissions_from_slim_doc(


def gdrive_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
    if the document doesn't already exists in postgres, we create
@@ -142,19 +143,17 @@ def gdrive_doc_sync(

    slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)

    document_external_accesses = []
    for slim_doc_batch in slim_doc_generator:
        for slim_doc in slim_doc_batch:
            ext_access = _get_permissions_from_slim_doc(
                google_drive_connector=google_drive_connector,
                slim_doc=slim_doc,
            )
            batch_add_non_web_user_if_not_exists__no_commit(
                db_session=db_session,
                emails=list(ext_access.external_user_emails),
            )
            upsert_document_external_perms__no_commit(
                db_session=db_session,
                doc_id=slim_doc.id,
                external_access=ext_access,
                source_type=cc_pair.connector.source,
            document_external_accesses.append(
                DocExternalAccess(
                    external_access=ext_access,
                    doc_id=slim_doc.id,
                )
            )
    return document_external_accesses
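Why the switch to permission.get("domain") above matters, as a small standalone illustration (the payload shown is invented, not taken from the Drive API docs):

# A "domain"-type permission entry can arrive without a "domain" field;
# permission["domain"] would raise KeyError, while .get() simply fails the match.
permission: dict[str, str] = {"type": "domain"}  # hypothetical payload
company_domain = "example.com"

is_public = permission.get("domain") == company_domain
print(is_public)  # False, and no exception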
@@ -1,21 +1,16 @@
from sqlalchemy.orm import Session

from danswer.connectors.google_drive.connector import GoogleDriveConnector
from danswer.connectors.google_utils.google_utils import execute_paginated_retrieval
from danswer.connectors.google_utils.resources import get_admin_service
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit

logger = setup_logger()


def gdrive_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[ExternalUserGroup]:
    google_drive_connector = GoogleDriveConnector(
        **cc_pair.connector.connector_specific_config
    )
@@ -44,20 +39,14 @@ def gdrive_group_sync(
        ):
            group_member_emails.append(member["email"])

        # Add group members to DB and get their IDs
        group_members = batch_add_non_web_user_if_not_exists__no_commit(
            db_session=db_session, emails=group_member_emails
        )
        if group_members:
            danswer_groups.append(
                ExternalUserGroup(
                    id=group_email, user_ids=[user.id for user in group_members]
                )
            )
        if not group_member_emails:
            continue

    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=danswer_groups,
        source=cc_pair.connector.source,
    )
        danswer_groups.append(
            ExternalUserGroup(
                id=group_email,
                user_emails=list(group_member_emails),
            )
        )

    return danswer_groups
@@ -1,115 +0,0 @@
from datetime import datetime
from datetime import timezone

from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_documents
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.utils.logger import setup_logger
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP

logger = setup_logger()


def run_external_group_permission_sync(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
    if cc_pair is None:
        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

    source_type = cc_pair.connector.source
    group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)

    if group_sync_func is None:
        # Not all sync connectors support group permissions so this is fine
        return

    try:
        # This function updates:
        # - the user_email <-> external_user_group_id mapping
        # in postgres without committing
        logger.debug(f"Syncing groups for {source_type}")
        if group_sync_func is not None:
            group_sync_func(
                db_session,
                cc_pair,
            )

        # update postgres
        db_session.commit()
    except Exception:
        logger.exception("Error Syncing Group Permissions")
        db_session.rollback()


def run_external_doc_permission_sync(
    db_session: Session,
    cc_pair_id: int,
) -> None:
    cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
    if cc_pair is None:
        raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

    source_type = cc_pair.connector.source

    doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
    last_time_perm_sync = cc_pair.last_time_perm_sync

    if doc_sync_func is None:
        raise ValueError(
            f"No permission sync function found for source type: {source_type}"
        )

    try:
        # This function updates:
        # - the user_email <-> document mapping
        # - the external_user_group_id <-> document mapping
        # in postgres without committing
        logger.info(f"Syncing docs for {source_type}")
        doc_sync_func(
            db_session,
            cc_pair,
        )

        # Get the document ids for the cc pair
        document_ids_for_cc_pair = get_document_ids_for_connector_credential_pair(
            db_session=db_session,
            connector_id=cc_pair.connector_id,
            credential_id=cc_pair.credential_id,
        )

        # This function fetches the updated access for the documents
        # and returns a dictionary of document_ids and access
        # This is the access we want to update vespa with
        docs_access = get_access_for_documents(
            document_ids=document_ids_for_cc_pair,
            db_session=db_session,
        )

        # Then we build the update requests to update vespa
        update_reqs = [
            UpdateRequest(document_ids=[doc_id], access=doc_access)
            for doc_id, doc_access in docs_access.items()
        ]

        # Don't bother sync-ing secondary, it will be sync-ed after switch anyway
        document_index = get_current_primary_default_document_index(db_session)

        # update vespa
        document_index.update(update_reqs)

        cc_pair.last_time_perm_sync = datetime.now(timezone.utc)

        # update postgres
        db_session.commit()
        logger.info(f"Successfully synced docs for {source_type}")
    except Exception:
        logger.exception("Error Syncing Document Permissions")
        cc_pair.last_time_perm_sync = last_time_perm_sync
        db_session.rollback()
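For reference, a condensed sketch of the driver logic this deleted module used to perform after the per-connector sync ran (pulled from the removed code above; the celery-based replacement is not shown in this diff).

from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_documents
from danswer.document_index.factory import get_current_primary_default_document_index
from danswer.document_index.interfaces import UpdateRequest


def push_access_to_vespa(db_session: Session, document_ids: list[str]) -> None:
    # Fetch the updated access for the documents, build one UpdateRequest per
    # document, and write the result to the primary index.
    docs_access = get_access_for_documents(
        document_ids=document_ids,
        db_session=db_session,
    )
    update_reqs = [
        UpdateRequest(document_ids=[doc_id], access=doc_access)
        for doc_id, doc_access in docs_access.items()
    ]
    document_index = get_current_primary_default_document_index(db_session)
    document_index.update(update_reqs)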
@@ -1,16 +1,12 @@
from slack_sdk import WebClient
from sqlalchemy.orm import Session

from danswer.access.models import DocExternalAccess
from danswer.access.models import ExternalAccess
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import InputType
from danswer.connectors.slack.connector import get_channels
from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.document import upsert_document_external_perms__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map


@@ -18,22 +14,15 @@ logger = setup_logger()


def _get_slack_document_ids_and_channels(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> dict[str, list[str]]:
    # Get all document ids that need their permissions updated
    runnable_connector = instantiate_connector(
        db_session=db_session,
        source=cc_pair.connector.source,
        input_type=InputType.SLIM_RETRIEVAL,
        connector_specific_config=cc_pair.connector.connector_specific_config,
        credential=cc_pair.credential,
    )
    slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
    slack_connector.load_credentials(cc_pair.credential.credential_json)

    assert isinstance(runnable_connector, SlimConnector)
    slim_doc_generator = slack_connector.retrieve_all_slim_documents()

    channel_doc_map: dict[str, list[str]] = {}
    for doc_metadata_batch in runnable_connector.retrieve_all_slim_documents():
    for doc_metadata_batch in slim_doc_generator:
        for doc_metadata in doc_metadata_batch:
            if doc_metadata.perm_sync_data is None:
                continue
@@ -46,13 +35,11 @@ def _get_slack_document_ids_and_channels(


def _fetch_workspace_permissions(
    db_session: Session,
    user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
    user_emails = set()
    for email in user_id_to_email_map.values():
        user_emails.add(email)
    batch_add_non_web_user_if_not_exists__no_commit(db_session, list(user_emails))
    return ExternalAccess(
        external_user_emails=user_emails,
        # No group<->document mapping for slack
@@ -63,7 +50,6 @@ def _fetch_workspace_permissions(


def _fetch_channel_permissions(
    db_session: Session,
    slack_client: WebClient,
    workspace_permissions: ExternalAccess,
    user_id_to_email_map: dict[str, str],
@@ -113,9 +99,6 @@ def _fetch_channel_permissions(
                # If no email is found, we skip the user
                continue
            user_id_to_email_map[member_id] = member_email
            batch_add_non_web_user_if_not_exists__no_commit(
                db_session, [member_email]
            )

            member_emails.add(member_email)

@@ -131,9 +114,8 @@ def _fetch_channel_permissions(


def slack_doc_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[DocExternalAccess]:
    """
    Adds the external permissions to the documents in postgres
    if the document doesn't already exists in postgres, we create
@@ -145,19 +127,18 @@ def slack_doc_sync(
    )
    user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
    channel_doc_map = _get_slack_document_ids_and_channels(
        db_session=db_session,
        cc_pair=cc_pair,
    )
    workspace_permissions = _fetch_workspace_permissions(
        db_session=db_session,
        user_id_to_email_map=user_id_to_email_map,
    )
    channel_permissions = _fetch_channel_permissions(
        db_session=db_session,
        slack_client=slack_client,
        workspace_permissions=workspace_permissions,
        user_id_to_email_map=user_id_to_email_map,
    )

    document_external_accesses = []
    for channel_id, ext_access in channel_permissions.items():
        doc_ids = channel_doc_map.get(channel_id)
        if not doc_ids:
@@ -165,9 +146,10 @@ def slack_doc_sync(
            continue

        for doc_id in doc_ids:
            upsert_document_external_perms__no_commit(
                db_session=db_session,
                doc_id=doc_id,
                external_access=ext_access,
                source_type=cc_pair.connector.source,
            document_external_accesses.append(
                DocExternalAccess(
                    external_access=ext_access,
                    doc_id=doc_id,
                )
            )
    return document_external_accesses
@@ -5,14 +5,11 @@ SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EM
THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING
"""
from slack_sdk import WebClient
from sqlalchemy.orm import Session

from danswer.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_non_web_user_if_not_exists__no_commit
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.slack.utils import fetch_user_id_to_email_map

logger = setup_logger()
@@ -29,7 +26,6 @@ def _get_slack_group_ids(


def _get_slack_group_members_email(
    db_session: Session,
    slack_client: WebClient,
    group_name: str,
    user_id_to_email_map: dict[str, str],
@@ -49,18 +45,14 @@ def _get_slack_group_members_email(
            # If no email is found, we skip the user
            continue
        user_id_to_email_map[member_id] = member_email
        batch_add_non_web_user_if_not_exists__no_commit(
            db_session, [member_email]
        )
        group_member_emails.append(member_email)

    return group_member_emails


def slack_group_sync(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
) -> None:
) -> list[ExternalUserGroup]:
    slack_client = WebClient(
        token=cc_pair.credential.credential_json["slack_bot_token"]
    )
@@ -69,24 +61,13 @@ def slack_group_sync(
    danswer_groups: list[ExternalUserGroup] = []
    for group_name in _get_slack_group_ids(slack_client):
        group_member_emails = _get_slack_group_members_email(
            db_session=db_session,
            slack_client=slack_client,
            group_name=group_name,
            user_id_to_email_map=user_id_to_email_map,
        )
        group_members = batch_add_non_web_user_if_not_exists__no_commit(
            db_session=db_session, emails=group_member_emails
        if not group_member_emails:
            continue
        danswer_groups.append(
            ExternalUserGroup(id=group_name, user_emails=group_member_emails)
        )
        if group_members:
            danswer_groups.append(
                ExternalUserGroup(
                    id=group_name, user_ids=[user.id for user in group_members]
                )
            )

    replace_user__ext_group_for_cc_pair__no_commit(
        db_session=db_session,
        cc_pair_id=cc_pair.id,
        group_defs=danswer_groups,
        source=cc_pair.connector.source,
    )
    return danswer_groups
@@ -1,9 +1,9 @@
from collections.abc import Callable

from sqlalchemy.orm import Session

from danswer.access.models import DocExternalAccess
from danswer.configs.constants import DocumentSource
from danswer.db.models import ConnectorCredentialPair
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
from ee.danswer.external_permissions.gmail.doc_sync import gmail_doc_sync
@@ -12,12 +12,18 @@ from ee.danswer.external_permissions.google_drive.group_sync import gdrive_group
from ee.danswer.external_permissions.slack.doc_sync import slack_doc_sync

# Defining the input/output types for the sync functions
SyncFuncType = Callable[
DocSyncFuncType = Callable[
    [
        Session,
        ConnectorCredentialPair,
    ],
    None,
    list[DocExternalAccess],
]

GroupSyncFuncType = Callable[
    [
        ConnectorCredentialPair,
    ],
    list[ExternalUserGroup],
]

# These functions update:
@@ -25,7 +31,7 @@ SyncFuncType = Callable[
# - the external_user_group_id <-> document mapping
# in postgres without committing
# THIS ONE IS NECESSARY FOR AUTO SYNC TO WORK
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
    DocumentSource.GOOGLE_DRIVE: gdrive_doc_sync,
    DocumentSource.CONFLUENCE: confluence_doc_sync,
    DocumentSource.SLACK: slack_doc_sync,
@@ -36,19 +42,21 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
# - the user_email <-> external_user_group_id mapping
# in postgres without committing
# THIS ONE IS OPTIONAL ON AN APP BY APP BASIS
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, SyncFuncType] = {
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
    DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
    DocumentSource.CONFLUENCE: confluence_group_sync,
}


# If nothing is specified here, we run the doc_sync every time the celery beat runs
PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
    # Polling is not supported so we fetch all doc permissions every 5 minutes
    DocumentSource.CONFLUENCE: 5 * 60,
    DocumentSource.SLACK: 5 * 60,
}

EXTERNAL_GROUP_SYNC_PERIOD: int = 30  # 30 seconds


def check_if_valid_sync_source(source_type: DocumentSource) -> bool:
    return source_type in DOC_PERMISSIONS_FUNC_MAP
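An illustrative consumer of the renamed period map above; only the map and the comment about the celery beat come from this file, the helper itself is hypothetical.

from danswer.configs.constants import DocumentSource
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS


def due_for_doc_sync(source: DocumentSource, seconds_since_last_sync: float) -> bool:
    period = DOC_PERMISSION_SYNC_PERIODS.get(source)
    if period is None:
        # No period configured: sync on every celery beat, per the comment above.
        return True
    return seconds_since_last_sync >= period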
@@ -1,4 +1,5 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import OpenID

from danswer.auth.users import auth_backend
@@ -59,6 +60,31 @@ def get_application() -> FastAPI:
    if MULTI_TENANT:
        add_tenant_id_middleware(application, logger)

    if AUTH_TYPE == AuthType.CLOUD:
        oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
        include_router_with_global_prefix_prepended(
            application,
            create_danswer_oauth_router(
                oauth_client,
                auth_backend,
                USER_AUTH_SECRET,
                associate_by_email=True,
                is_verified_by_default=True,
                # Points the user back to the login page
                redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
            ),
            prefix="/auth/oauth",
            tags=["auth"],
        )

        # Need basic auth router for `logout` endpoint
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_logout_router(auth_backend),
            prefix="/auth",
            tags=["auth"],
        )

    if AUTH_TYPE == AuthType.OIDC:
        include_router_with_global_prefix_prepended(
            application,
@@ -73,6 +99,7 @@ def get_application() -> FastAPI:
            prefix="/auth/oidc",
            tags=["auth"],
        )

        # need basic auth router for `logout` endpoint
        include_router_with_global_prefix_prepended(
            application,
backend/ee/danswer/seeding/load_docs.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import json
import os
from typing import cast
from typing import List

from cohere import Client

from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY

Embedding = List[float]


def load_processed_docs(cohere_enabled: bool) -> list[dict]:
    base_path = os.path.join(os.getcwd(), "danswer", "seeding")

    if cohere_enabled and COHERE_DEFAULT_API_KEY:
        initial_docs_path = os.path.join(base_path, "initial_docs_cohere.json")
        processed_docs = json.load(open(initial_docs_path))

        cohere_client = Client(api_key=COHERE_DEFAULT_API_KEY)
        embed_model = "embed-english-v3.0"

        for doc in processed_docs:
            title_embed_response = cohere_client.embed(
                texts=[doc["title"]],
                model=embed_model,
                input_type="search_document",
            )
            content_embed_response = cohere_client.embed(
                texts=[doc["content"]],
                model=embed_model,
                input_type="search_document",
            )

            doc["title_embedding"] = cast(
                List[Embedding], title_embed_response.embeddings
            )[0]
            doc["content_embedding"] = cast(
                List[Embedding], content_embed_response.embeddings
            )[0]
    else:
        initial_docs_path = os.path.join(base_path, "initial_docs.json")
        processed_docs = json.load(open(initial_docs_path))

    return processed_docs
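A hypothetical call site for the new seeding helper; whether Cohere embeddings are used is decided by the tenant provisioning code further down in this diff.

from ee.danswer.seeding.load_docs import load_processed_docs

docs = load_processed_docs(cohere_enabled=False)
# Each entry carries at least "title" and "content"; with cohere_enabled=True and
# COHERE_DEFAULT_API_KEY set, "title_embedding" and "content_embedding" are added.
print(f"{len(docs)} seed documents loaded")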
@@ -38,3 +38,4 @@ class ImpersonateRequest(BaseModel):
class TenantCreationPayload(BaseModel):
    tenant_id: str
    email: str
    referral_source: str | None = None
@@ -4,6 +4,7 @@ import uuid

import aiohttp  # Async HTTP client
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.auth.users import exceptions
@@ -13,6 +14,8 @@ from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
@@ -41,7 +44,9 @@ from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)


async def get_or_create_tenant_id(email: str) -> str:
async def get_or_create_tenant_id(
    email: str, referral_source: str | None = None
) -> str:
    """Get existing tenant ID for an email or create a new tenant if none exists."""
    if not MULTI_TENANT:
        return POSTGRES_DEFAULT_SCHEMA
@@ -51,7 +56,7 @@ async def get_or_create_tenant_id(email: str) -> str:
    except exceptions.UserNotExists:
        # If tenant does not exist and in Multi tenant mode, provision a new tenant
        try:
            tenant_id = await create_tenant(email)
            tenant_id = await create_tenant(email, referral_source)
        except Exception as e:
            logger.error(f"Tenant provisioning failed: {e}")
            raise HTTPException(status_code=500, detail="Failed to provision tenant.")
@@ -64,13 +69,13 @@ async def get_or_create_tenant_id(email: str) -> str:
    return tenant_id


async def create_tenant(email: str) -> str:
async def create_tenant(email: str, referral_source: str | None = None) -> str:
    tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
    try:
        # Provision tenant on data plane
        await provision_tenant(tenant_id, email)
        # Notify control plane
        await notify_control_plane(tenant_id, email)
        await notify_control_plane(tenant_id, email, referral_source)
    except Exception as e:
        logger.error(f"Tenant provisioning failed: {e}")
        await rollback_tenant_provisioning(tenant_id)
@@ -102,9 +107,19 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
        await asyncio.to_thread(run_alembic_migrations, tenant_id)

        with get_session_with_tenant(tenant_id) as db_session:
            setup_danswer(db_session, tenant_id)
            configure_default_api_keys(db_session)

            current_search_settings = (
                db_session.query(SearchSettings)
                .filter_by(status=IndexModelStatus.FUTURE)
                .first()
            )
            cohere_enabled = (
                current_search_settings is not None
                and current_search_settings.provider_type == EmbeddingProvider.COHERE
            )
            setup_danswer(db_session, tenant_id, cohere_enabled=cohere_enabled)

        add_users_to_tenant([email], tenant_id)

    except Exception as e:
@@ -117,14 +132,18 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


async def notify_control_plane(tenant_id: str, email: str) -> None:
async def notify_control_plane(
    tenant_id: str, email: str, referral_source: str | None = None
) -> None:
    logger.info("Fetching billing information")
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
    payload = TenantCreationPayload(
        tenant_id=tenant_id, email=email, referral_source=referral_source
    )

    async with aiohttp.ClientSession() as session:
        async with session.post(
@@ -200,11 +219,51 @@ def configure_default_api_keys(db_session: Session) -> None:
            provider_type=EmbeddingProvider.COHERE,
            api_key=COHERE_DEFAULT_API_KEY,
        )

        try:
            logger.info("Attempting to upsert Cohere cloud embedding provider")
            upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
        except Exception as e:
            logger.error(f"Failed to configure Cohere embedding provider: {e}")
            logger.info("Successfully upserted Cohere cloud embedding provider")

            logger.info("Updating search settings with Cohere embedding model details")
            query = (
                select(SearchSettings)
                .where(SearchSettings.status == IndexModelStatus.FUTURE)
                .order_by(SearchSettings.id.desc())
            )
            result = db_session.execute(query)
            current_search_settings = result.scalars().first()

            if current_search_settings:
                current_search_settings.model_name = (
                    "embed-english-v3.0"  # Cohere's latest model as of now
                )
                current_search_settings.model_dim = (
                    1024  # Cohere's embed-english-v3.0 dimension
                )
                current_search_settings.provider_type = EmbeddingProvider.COHERE
                current_search_settings.index_name = (
                    "danswer_chunk_cohere_embed_english_v3_0"
                )
                current_search_settings.query_prefix = ""
                current_search_settings.passage_prefix = ""
                db_session.commit()
            else:
                raise RuntimeError(
                    "No search settings specified, DB is not in a valid state"
                )
            logger.info("Fetching updated search settings to verify changes")
            updated_query = (
                select(SearchSettings)
                .where(SearchSettings.status == IndexModelStatus.PRESENT)
                .order_by(SearchSettings.id.desc())
            )
            updated_result = db_session.execute(updated_query)
            updated_result.scalars().first()

        except Exception:
            logger.exception("Failed to configure Cohere embedding provider")
    else:
        logger.error(
        logger.info(
            "COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
        )
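A small sketch of the new referral flow (the values are invented): the referral source picked up at signup rides along to the control-plane notification. The import path is an assumption based on the tenants models hunk earlier in this diff.

from ee.danswer.server.tenants.models import TenantCreationPayload  # assumed path

payload = TenantCreationPayload(
    tenant_id="tenant_1b2c3d4e",   # invented id
    email="founder@example.com",   # invented email
    referral_source="twitter",     # optional, may be None
)
print(payload)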
@@ -26,4 +26,5 @@ lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
pandas-stubs==2.2.3.241009
cohere==5.6.1
@@ -1 +1,2 @@
python3-saml==1.15.0
python3-saml==1.15.0
cohere==5.6.1
@@ -42,7 +42,7 @@ def run_jobs() -> None:
        "--loglevel=INFO",
        "--hostname=light@%n",
        "-Q",
        "vespa_metadata_sync,connector_deletion",
        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
    ]

    cmd_worker_heavy = [
@@ -56,7 +56,7 @@ def run_jobs() -> None:
        "--loglevel=INFO",
        "--hostname=heavy@%n",
        "-Q",
        "connector_pruning",
        "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
    ]

    cmd_worker_indexing = [
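Restated from the two command lists above (no new configuration), the queue-to-worker assignment after this change:

# Which celery queues each dev worker consumes after this change.
WORKER_QUEUES = {
    "light": [
        "vespa_metadata_sync",
        "connector_deletion",
        "doc_permissions_upsert",
    ],
    "heavy": [
        "connector_pruning",
        "connector_doc_permissions_sync",
        "connector_external_group_sync",
    ],
}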
Some files were not shown because too many files have changed in this diff.