Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | ca418fdcf2 | |
| | 4a1230f028 | |
@@ -24,8 +24,6 @@ env:
  GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
  GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
  GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
  # Slab
  SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}

jobs:
  connectors-check:
@@ -73,7 +73,6 @@ RUN apt-get update && \
    rm -rf /var/lib/apt/lists/* && \
    rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key


# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
@@ -1,35 +0,0 @@
"""add web ui option to slack config

Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add show_continue_in_web_ui with default False to all existing channel_configs
    op.execute(
        """
        UPDATE slack_channel_config
        SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
        WHERE NOT channel_config ? 'show_continue_in_web_ui'
        """
    )


def downgrade() -> None:
    # Remove show_continue_in_web_ui from all channel_configs
    op.execute(
        """
        UPDATE slack_channel_config
        SET channel_config = channel_config - 'show_continue_in_web_ui'
        """
    )
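Editor's note for context: the `show_continue_in_web_ui` flag written by this migration is read back off the per-channel config when the Slack bot builds its response blocks (see the `build_slack_response_blocks` hunk later in this diff). A minimal illustrative sketch of that check follows; the helper name is hypothetical and is not part of the diff.

```python
# Illustrative sketch only: the helper name is made up; the real check lives in
# the Slack block-building code shown further down in this diff.
from typing import Any


def should_show_continue_in_web_ui(channel_conf: dict[str, Any] | None) -> bool:
    # A missing key (e.g. configs created before this migration) or an explicit
    # false value both mean the "Continue in Web UI" button is not rendered.
    return bool(channel_conf and channel_conf.get("show_continue_in_web_ui"))
```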
@@ -1,27 +0,0 @@
"""add auto scroll to user model

Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
    )


def downgrade() -> None:
    op.drop_column("user", "auto_scroll")
@@ -1,30 +0,0 @@
"""add indexing trigger to cc_pair

Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "indexing_trigger",
            sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "indexing_trigger")
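Editor's note: with `native_enum=False`, SQLAlchemy stores the enum as a plain VARCHAR rather than creating a PostgreSQL ENUM type, which is why `downgrade()` only needs to drop the column. A sketch of how the matching model column might be declared is below; the model definition is assumed (not shown in this diff) and may differ from the real `danswer.db.models` code.

```python
# Hypothetical sketch of the ORM side of this migration; names and enum values
# are assumptions, only the column type mirrors the migration above.
from enum import Enum as PyEnum

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class IndexingMode(str, PyEnum):
    UPDATE = "update"
    REINDEX = "reindex"


class Base(DeclarativeBase):
    pass


class ConnectorCredentialPair(Base):
    __tablename__ = "connector_credential_pair"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Stored as VARCHAR because native_enum=False; NULL means "no manual trigger".
    indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
        sa.Enum(IndexingMode, name="indexingmode", native_enum=False),
        nullable=True,
    )
```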
@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
        )
        return UserPreferences(**preferences_data)
    except KvKeyNotFoundError:
        return UserPreferences(
            chosen_assistants=None, default_model=None, auto_scroll=True
        )
        return UserPreferences(chosen_assistants=None, default_model=None)


def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
@@ -5,6 +5,7 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -36,7 +37,7 @@ class TaskDependencyError(RuntimeError):
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -59,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
self.app, cc_pair_id, db_session, lock_beat, tenant_id
|
||||
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
@@ -85,6 +86,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> int | None:
|
||||
|
||||
@@ -8,7 +8,6 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -28,7 +27,7 @@ from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
RedisConnectorPermissionSyncData,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import doc_permission_sync_ctx
|
||||
@@ -139,7 +138,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -163,7 +162,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
app.send_task(
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -175,8 +174,8 @@ def try_creating_permissions_sync_task(
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=None,
|
||||
)
|
||||
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
@@ -242,17 +241,13 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if doc_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
|
||||
)
|
||||
raise ValueError(f"No doc sync func found for {source_type}")
|
||||
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
|
||||
payload = redis_connector.permissions.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
payload = RedisConnectorPermissionSyncData(
|
||||
started=datetime.now(timezone.utc),
|
||||
)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
@@ -8,7 +8,6 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -25,9 +24,6 @@ from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
)
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
@@ -53,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip external group sync if not active
|
||||
# skip pruning if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
@@ -111,7 +107,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
tasks_created = try_creating_permissions_sync_task(
|
||||
self.app, cc_pair_id, r, tenant_id
|
||||
)
|
||||
if not tasks_created:
|
||||
@@ -129,7 +125,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
@@ -160,7 +156,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
_ = app.send_task(
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -170,13 +166,8 @@ def try_creating_external_group_sync_task(
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
payload = RedisConnectorExternalGroupSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
# set a basic fence to start
|
||||
redis_connector.external_group_sync.set_fence(True)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
@@ -204,7 +195,7 @@ def connector_external_group_sync_generator_task(
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles external group syncing for a given connector credential pair
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
This task assumes that the task has already been properly fenced
|
||||
"""
|
||||
|
||||
@@ -212,7 +203,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
|
||||
@@ -237,13 +228,9 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
raise ValueError(f"No external group sync func found for {source_type}")
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
logger.info(f"Syncing docs for {source_type}")
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
@@ -262,6 +249,7 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
@@ -272,6 +260,6 @@ def connector_external_group_sync_generator_task(
|
||||
raise e
|
||||
finally:
|
||||
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
|
||||
redis_connector.external_group_sync.set_fence(None)
|
||||
redis_connector.external_group_sync.set_fence(False)
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
@@ -25,13 +25,11 @@ from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingMode
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
@@ -39,13 +37,12 @@ from danswer.db.index_attempt import delete_index_attempt
|
||||
from danswer.db.index_attempt import get_all_index_attempts_by_status
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.search_settings import get_active_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
@@ -162,7 +159,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -175,8 +172,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
locked = True
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
@@ -210,10 +205,17 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
for search_settings_instance in search_settings_list:
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings = [primary_search_settings]
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings.append(secondary_search_settings)
|
||||
|
||||
for search_settings_instance in search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
@@ -229,46 +231,22 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
|
||||
if not _should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
search_settings_primary=search_settings_primary,
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
secondary_index_building=len(search_settings) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
# the indexing trigger is only checked and cleared with the primary search settings
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
|
||||
reindex = True
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing manual trigger detected: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
f"indexing_mode={cc_pair.indexing_trigger}"
|
||||
)
|
||||
|
||||
mark_ccpair_with_indexing_trigger(
|
||||
cc_pair.id, None, db_session
|
||||
)
|
||||
|
||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||
# prevents us from starving out certain attempts
|
||||
attempt_id = try_creating_indexing_task(
|
||||
self.app,
|
||||
cc_pair,
|
||||
search_settings_instance,
|
||||
reindex,
|
||||
False,
|
||||
db_session,
|
||||
r,
|
||||
tenant_id,
|
||||
@@ -278,7 +256,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
f"Connector indexing queued: "
|
||||
f"index_attempt={attempt_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings_instance.id}"
|
||||
f"search_settings={search_settings_instance.id} "
|
||||
)
|
||||
tasks_created += 1
|
||||
|
||||
@@ -303,6 +281,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -310,14 +289,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
return tasks_created
|
||||
|
||||
@@ -326,7 +304,6 @@ def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
@@ -391,11 +368,6 @@ def _should_index(
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
if not last_index:
|
||||
return True
|
||||
@@ -523,11 +495,8 @@ def try_creating_indexing_task(
|
||||
return index_attempt_id
|
||||
|
||||
|
||||
@shared_task(
|
||||
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
|
||||
)
|
||||
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
@@ -540,10 +509,6 @@ def connector_indexing_proxy_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
if not self.request.id:
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
|
||||
job = client.submit(
|
||||
@@ -572,30 +537,8 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing proxy - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
sleep(10)
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
|
||||
@@ -46,7 +46,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
@@ -59,7 +58,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from danswer.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
RedisConnectorPermissionSyncData,
|
||||
)
|
||||
from danswer.redis.redis_connector_index import RedisConnectorIndex
|
||||
from danswer.redis.redis_connector_prune import RedisConnectorPrune
|
||||
@@ -589,7 +588,7 @@ def monitor_ccpair_permissions_taskset(
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
payload: RedisConnectorPermissionSyncData | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
@@ -597,7 +596,9 @@ def monitor_ccpair_permissions_taskset(
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
@@ -677,15 +678,11 @@ def monitor_ccpair_indexing_taskset(
|
||||
|
||||
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
@@ -695,7 +692,6 @@ def monitor_ccpair_indexing_taskset(
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
@@ -728,7 +724,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_celery = celery_get_queue_length("celery", r)
|
||||
n_indexing = celery_get_queue_length(
|
||||
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.background.celery.apps.beat import celery_app
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = celery_app
|
||||
app = celery_app
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Factory stub for running celery worker / celery beat."""
|
||||
from celery import Celery
|
||||
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
app = fetch_versioned_implementation(
|
||||
"danswer.background.celery.apps.primary", "celery_app"
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import mark_attempt_canceled
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
@@ -88,10 +87,6 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
@@ -213,7 +208,9 @@ def _run_indexing(
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
raise RuntimeError(
|
||||
"_run_indexing: Connector stop signal detected"
|
||||
)
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
db_session.refresh(db_cc_pair)
|
||||
@@ -307,16 +304,26 @@ def _run_indexing(
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
|
||||
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
|
||||
)
|
||||
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
mark_attempt_canceled(
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
reason=str(e),
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
@@ -328,37 +335,6 @@ def _run_indexing(
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
# Only mark the attempt as a complete failure if this is the first indexing window.
|
||||
# Otherwise, some progress was made - the next run will not start from the beginning.
|
||||
# In this case, it is not accurate to mark it as a failure. When the next run begins,
|
||||
# if that fails immediately, it will be marked as a failure.
|
||||
#
|
||||
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or not db_cc_pair.status.is_active()
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt.id,
|
||||
db_session,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
if INDEXING_TRACER_INTERVAL > 0:
|
||||
tracer.stop()
|
||||
raise e
|
||||
|
||||
# break => similar to success case. As mentioned above, if the next run fails for the same
|
||||
# reason it will then be marked as a failure
|
||||
|
||||
@@ -605,7 +605,6 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
@@ -493,6 +493,10 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
|
||||
@@ -70,9 +70,7 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
)
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
|
||||
@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
- Load Connector:
  - Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
    documents via a connector's API or loading the documents from some sort of a dump file.
- Poll Connector:
- Poll connector:
  - Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
    changes and additions since the last round of polling. This connector helps keep the document index up to date
    without needing to fetch/embed/index every document, which would be too slow to do frequently on large sets of
    documents.
- Slim Connector:
  - This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
  - This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
  - This is used by our pruning job which removes old documents from the index.
  - The optional start and end datetimes can be ignored.
- Event Based connectors:
  - Connectors that listen to events and update documents accordingly.
  - Currently not used by the background job, this exists for future design purposes.
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor-created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)

For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)

All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.


#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.

The `__init__` should take arguments for configuring what documents the connector will fetch and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
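Editor's note: as a concrete illustration of the Slim Connector flow described in the documentation hunk above, here is a minimal sketch modeled on the `SlabConnector.retrieve_all_slim_documents` implementation that this diff removes. The example class, its `api_client`, and the batch size are hypothetical; the imported interfaces are the ones referenced elsewhere in this diff.

```python
# Illustrative sketch, not part of the diff. Assumes the interfaces imported in
# the Slab connector hunk below; the ExampleSlimConnector and api_client are made up.
from typing import Any

from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import SlimDocument

_SLIM_BATCH_SIZE = 1000  # batch size is illustrative; real connectors choose their own


class ExampleSlimConnector(SlimConnector):
    """Hypothetical connector that only fetches document IDs, e.g. for pruning."""

    def __init__(self, api_client: Any) -> None:
        # api_client is assumed to expose get_all_document_ids() -> iterable of str
        self._api_client = api_client

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        # Credentials handling omitted for brevity in this sketch.
        return None

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateSlimDocumentOutput:
        # Per the docs above, the optional start/end window may be ignored.
        slim_doc_batch: list[SlimDocument] = []
        for doc_id in self._api_client.get_all_document_ids():
            slim_doc_batch.append(SlimDocument(id=doc_id))
            if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
                yield slim_doc_batch
                slim_doc_batch = []
        if slim_doc_batch:
            yield slim_doc_batch
```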
@@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
_SLIM_DOC_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@@ -301,8 +301,5 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=perm_sync_data,
|
||||
)
|
||||
)
|
||||
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
|
||||
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
|
||||
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
|
||||
|
||||
yield doc_metadata_list
|
||||
yield doc_metadata_list
|
||||
doc_metadata_list = []
|
||||
|
||||
@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_DEFAULT_PAGINATION_LIMIT = 100
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
|
||||
@@ -12,15 +12,12 @@ from dateutil import parser
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -31,8 +28,6 @@ logger = setup_logger()
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
|
||||
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
slab_bot_token: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self._slab_bot_token: str | None = None
|
||||
self.slab_bot_token = slab_bot_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._slab_bot_token = credentials["slab_bot_token"]
|
||||
self.slab_bot_token = credentials["slab_bot_token"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def slab_bot_token(self) -> str:
|
||||
if self._slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
return self._slab_bot_token
|
||||
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
yield from self._iterate_posts(
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=post_id,
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -18,30 +18,20 @@ from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from danswer.danswerbot.slack.formatting import format_slack_message
|
||||
from danswer.danswerbot.slack.icons import source_to_github_img_link
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
|
||||
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
@@ -111,12 +101,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
|
||||
return chunks
|
||||
|
||||
|
||||
def _clean_markdown_link_text(text: str) -> str:
|
||||
def clean_markdown_link_text(text: str) -> str:
|
||||
# Remove any newlines within the text
|
||||
return text.replace("\n", " ").strip()
|
||||
|
||||
|
||||
def _build_qa_feedback_block(
|
||||
def build_qa_feedback_block(
|
||||
message_id: int, feedback_reminder_id: str | None = None
|
||||
) -> Block:
|
||||
return ActionsBlock(
|
||||
@@ -125,6 +115,7 @@ def _build_qa_feedback_block(
|
||||
ButtonElement(
|
||||
action_id=LIKE_BLOCK_ACTION_ID,
|
||||
text="👍 Helpful",
|
||||
style="primary",
|
||||
value=feedback_reminder_id,
|
||||
),
|
||||
ButtonElement(
|
||||
@@ -164,7 +155,7 @@ def get_document_feedback_blocks() -> Block:
|
||||
)
|
||||
|
||||
|
||||
def _build_doc_feedback_block(
|
||||
def build_doc_feedback_block(
|
||||
message_id: int,
|
||||
document_id: str,
|
||||
document_rank: int,
|
||||
@@ -191,7 +182,7 @@ def get_restate_blocks(
|
||||
]
|
||||
|
||||
|
||||
def _build_documents_blocks(
|
||||
def build_documents_blocks(
|
||||
documents: list[SavedSearchDoc],
|
||||
message_id: int | None,
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
@@ -232,7 +223,7 @@ def _build_documents_blocks(
|
||||
|
||||
feedback: ButtonElement | dict = {}
|
||||
if message_id is not None:
|
||||
feedback = _build_doc_feedback_block(
|
||||
feedback = build_doc_feedback_block(
|
||||
message_id=message_id,
|
||||
document_id=d.document_id,
|
||||
document_rank=rank,
|
||||
@@ -250,7 +241,7 @@ def _build_documents_blocks(
|
||||
return section_blocks
|
||||
|
||||
|
||||
def _build_sources_blocks(
|
||||
def build_sources_blocks(
|
||||
cited_documents: list[tuple[int, SavedSearchDoc]],
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
) -> list[Block]:
|
||||
@@ -295,7 +286,7 @@ def _build_sources_blocks(
|
||||
+ ([days_ago_str] if days_ago_str else [])
|
||||
)
|
||||
|
||||
document_title = _clean_markdown_link_text(doc_sem_id)
|
||||
document_title = clean_markdown_link_text(doc_sem_id)
|
||||
img_link = source_to_github_img_link(d.source_type)
|
||||
|
||||
section_blocks.append(
|
||||
@@ -326,50 +317,7 @@ def _build_sources_blocks(
|
||||
return section_blocks
|
||||
|
||||
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
llm_doc_inds = answer.llm_selected_doc_indices or []
|
||||
llm_docs = [top_docs[i] for i in llm_doc_inds]
|
||||
remaining_docs = [
|
||||
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
|
||||
]
|
||||
priority_ordered_docs = llm_docs + remaining_docs
|
||||
if not priority_ordered_docs:
|
||||
return []
|
||||
|
||||
document_blocks = _build_documents_blocks(
|
||||
documents=priority_ordered_docs,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
if document_blocks:
|
||||
document_blocks = [DividerBlock()] + document_blocks
|
||||
return document_blocks
|
||||
|
||||
|
||||
def _build_citations_blocks(
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
citations = answer.citations or []
|
||||
cited_docs = []
|
||||
for citation in citations:
|
||||
matching_doc = next(
|
||||
(d for d in top_docs if d.document_id == citation.document_id),
|
||||
None,
|
||||
)
|
||||
if matching_doc:
|
||||
cited_docs.append((citation.citation_num, matching_doc))
|
||||
|
||||
cited_docs.sort()
|
||||
citations_block = _build_sources_blocks(cited_documents=cited_docs)
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_quotes_block(
|
||||
def build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
@@ -411,70 +359,58 @@ def _build_quotes_block(
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: OneShotQAResponse,
|
||||
def build_qa_response_blocks(
|
||||
message_id: int | None,
|
||||
answer: str | None,
|
||||
quotes: list[DanswerQuote] | None,
|
||||
source_filters: list[DocumentSource] | None,
|
||||
time_cutoff: datetime | None,
|
||||
favor_recent: bool,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
skip_ai_feedback: bool = False,
|
||||
feedback_reminder_id: str | None = None,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
quotes = answer.quotes.quotes if answer.quotes else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
or retrieval_info.recency_bias_multiplier > 1
|
||||
or retrieval_info.applied_source_filters
|
||||
):
|
||||
if time_cutoff or favor_recent or source_filters:
|
||||
filter_text = "Filters: "
|
||||
if retrieval_info.applied_source_filters:
|
||||
sources_str = ", ".join(
|
||||
[s.value for s in retrieval_info.applied_source_filters]
|
||||
)
|
||||
if source_filters:
|
||||
sources_str = ", ".join([s.value for s in source_filters])
|
||||
filter_text += f"`Sources in [{sources_str}]`"
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
or retrieval_info.recency_bias_multiplier > 1
|
||||
):
|
||||
if time_cutoff or favor_recent:
|
||||
filter_text += " and "
|
||||
if retrieval_info.applied_time_cutoff is not None:
|
||||
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
|
||||
if time_cutoff is not None:
|
||||
time_str = time_cutoff.strftime("%b %d, %Y")
|
||||
filter_text += f"`Docs Updated >= {time_str}` "
|
||||
if retrieval_info.recency_bias_multiplier > 1:
|
||||
if retrieval_info.applied_time_cutoff is not None:
|
||||
if favor_recent:
|
||||
if time_cutoff is not None:
|
||||
filter_text += "+ "
|
||||
filter_text += "`Prioritize Recently Updated Docs`"
|
||||
|
||||
filter_block = SectionBlock(text=f"_{filter_text}_")
|
||||
|
||||
if not formatted_answer:
|
||||
if not answer:
|
||||
answer_blocks = [
|
||||
SectionBlock(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
]
|
||||
else:
|
||||
answer_processed = decode_escapes(
|
||||
remove_slack_text_interactions(formatted_answer)
|
||||
)
|
||||
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
|
||||
if process_message_for_citations:
|
||||
answer_processed = _process_citations_for_slack(answer_processed)
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = _build_quotes_block(quotes)
|
||||
quotes_blocks = build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `_build_quotes_block()` did not give back any blocks
|
||||
# if no quotes OR `build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
@@ -489,37 +425,20 @@ def _build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if message_id is not None and not skip_ai_feedback:
|
||||
response_blocks.append(
|
||||
build_qa_feedback_block(
|
||||
message_id=message_id, feedback_reminder_id=feedback_reminder_id
|
||||
)
|
||||
)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
)
|
||||
return ActionsBlock(
|
||||
block_id=build_continue_in_web_ui_id(message_id),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
|
||||
text="Continue Chat in Danswer!",
|
||||
style="primary",
|
||||
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
|
||||
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
|
||||
return ActionsBlock(
|
||||
block_id=build_feedback_id(message_id) if message_id is not None else None,
|
||||
elements=[
|
||||
@@ -564,77 +483,3 @@ def build_follow_up_resolved_blocks(
|
||||
]
|
||||
)
|
||||
return [text_block, button_block]
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
answer: OneShotQAResponse,
|
||||
persona: Persona | None,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
skip_ai_feedback: bool = False,
|
||||
) -> list[Block]:
|
||||
"""
|
||||
This function is a top level function that builds all the blocks for the Slack response.
|
||||
It also handles combining all the blocks together.
|
||||
"""
|
||||
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
web_follow_up_block = []
|
||||
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
|
||||
web_follow_up_block.append(
|
||||
_build_continue_in_web_ui_block(
|
||||
tenant_id=tenant_id,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
follow_up_block = []
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
follow_up_block.append(
|
||||
_build_follow_up_block(message_id=answer.chat_message_id)
|
||||
)
|
||||
|
||||
ai_feedback_block = []
|
||||
if answer.chat_message_id is not None and not skip_ai_feedback:
|
||||
ai_feedback_block.append(
|
||||
_build_qa_feedback_block(
|
||||
message_id=answer.chat_message_id,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations:
|
||||
# if citations are enabled, only show cited documents
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
|
||||
all_blocks = (
|
||||
restate_question_block
|
||||
+ answer_blocks
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
return all_blocks
|
||||
|
||||
@@ -2,7 +2,6 @@ from enum import Enum
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
|
||||
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
|
||||
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
|
||||
|
||||
@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import build_feedback_id
|
||||
from danswer.danswerbot.slack.utils import decompose_action_id
|
||||
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
|
||||
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import get_channel_name_from_id
|
||||
from danswer.danswerbot.slack.utils import get_feedback_visibility
|
||||
from danswer.danswerbot.slack.utils import read_slack_thread
|
||||
@@ -267,7 +267,7 @@ def handle_followup_button(
|
||||
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
|
||||
remaining = None
|
||||
if tag_names:
|
||||
tag_ids, remaining = fetch_slack_user_ids_from_emails(
|
||||
tag_ids, remaining = fetch_user_ids_from_emails(
|
||||
tag_names, client.web_client
|
||||
)
|
||||
if remaining:
|
||||
|
||||
@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
handle_standard_answers,
|
||||
)
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
|
||||
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import slack_usage_report
|
||||
@@ -184,7 +184,7 @@ def handle_message(
|
||||
send_to: list[str] | None = None
|
||||
missing_users: list[str] | None = None
|
||||
if respond_member_group_list:
|
||||
send_to, missing_ids = fetch_slack_user_ids_from_emails(
|
||||
send_to, missing_ids = fetch_user_ids_from_emails(
|
||||
respond_member_group_list, client
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TypeVar

from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
@@ -24,7 +25,12 @@ from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
@@ -405,16 +411,62 @@ def handle_regular_answer(
        )
        return True

    all_blocks = build_slack_response_blocks(
        tenant_id=tenant_id,
        message_info=message_info,
        answer=answer,
        persona=persona,
        channel_conf=channel_conf,
        use_citations=use_citations,
    # If called with the DanswerBot slash command, the question is lost so we have to reshow it
    restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
    formatted_answer = format_slack_message(answer.answer) if answer.answer else None

    answer_blocks = build_qa_response_blocks(
        message_id=answer.chat_message_id,
        answer=formatted_answer,
        quotes=answer.quotes.quotes if answer.quotes else None,
        source_filters=retrieval_info.applied_source_filters,
        time_cutoff=retrieval_info.applied_time_cutoff,
        favor_recent=retrieval_info.recency_bias_multiplier > 1,
        # currently Personas don't support quotes
        # if citations are enabled, also don't use quotes
        skip_quotes=persona is not None or use_citations,
        process_message_for_citations=use_citations,
        feedback_reminder_id=feedback_reminder_id,
    )

    # Get the chunks fed to the LLM only, then fill with other docs
    llm_doc_inds = answer.llm_selected_doc_indices or []
    llm_docs = [top_docs[i] for i in llm_doc_inds]
    remaining_docs = [
        doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
    ]
    priority_ordered_docs = llm_docs + remaining_docs

    document_blocks = []
    citations_block = []
    # if citations are enabled, only show cited documents
    if use_citations:
        citations = answer.citations or []
        cited_docs = []
        for citation in citations:
            matching_doc = next(
                (d for d in top_docs if d.document_id == citation.document_id),
                None,
            )
            if matching_doc:
                cited_docs.append((citation.citation_num, matching_doc))

        cited_docs.sort()
        citations_block = build_sources_blocks(cited_documents=cited_docs)
    elif priority_ordered_docs:
        document_blocks = build_documents_blocks(
            documents=priority_ordered_docs,
            message_id=answer.chat_message_id,
        )
        document_blocks = [DividerBlock()] + document_blocks

    all_blocks = (
        restate_question_block + answer_blocks + citations_block + document_blocks
    )

    if channel_conf and channel_conf.get("follow_up_tags") is not None:
        all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))

    try:
        respond_in_thread(
            client=client,
@@ -3,9 +3,9 @@ import random
import re
import string
import time
import uuid
from typing import Any
from typing import cast
from typing import Optional

from retry import retry
from slack_sdk import WebClient
@@ -216,13 +216,6 @@ def build_feedback_id(
    return unique_prefix + ID_SEPARATOR + feedback_id


def build_continue_in_web_ui_id(
    message_id: int,
) -> str:
    unique_prefix = str(uuid.uuid4())[:10]
    return unique_prefix + ID_SEPARATOR + str(message_id)


def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
    """Decompose into query_id, document_id, document_rank, see above function"""
    try:
@@ -320,7 +313,7 @@ def get_channel_name_from_id(
        raise e


def fetch_slack_user_ids_from_emails(
def fetch_user_ids_from_emails(
    user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
    user_ids: list[str] = []
@@ -529,7 +522,7 @@ class SlackRateLimiter:
        self.last_reset_time = time.time()

    def notify(
        self, client: WebClient, channel: str, position: int, thread_ts: str | None
        self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
    ) -> None:
        respond_in_thread(
            client=client,
@@ -3,7 +3,6 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -31,7 +30,6 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
@@ -252,50 +250,6 @@ def create_chat_session(
    return chat_session


def duplicate_chat_session_for_user_from_slack(
    db_session: Session,
    user: User | None,
    chat_session_id: UUID,
) -> ChatSession:
    """
    This takes a chat session id for a session in Slack and:
    - Creates a new chat session in the DB
    - Tries to copy the persona from the original chat session
      (if it is available to the user clicking the button)
    - Sets the user to the given user (if provided)
    """
    chat_session = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=None,  # Ignore user permissions for this
        db_session=db_session,
    )
    if not chat_session:
        raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")

    # This enforces permissions and sets a default
    new_persona_id = get_best_persona_id_for_user(
        db_session=db_session,
        user=user,
        persona_id=chat_session.persona_id,
    )

    return create_chat_session(
        db_session=db_session,
        user_id=user.id if user else None,
        persona_id=new_persona_id,
        # Set this to empty string so the frontend will force a rename
        description="",
        llm_override=chat_session.llm_override,
        prompt_override=chat_session.prompt_override,
        # Chat sessions from Slack should put people in the chat UI, not the search
        one_shot=False,
        # Chat is in UI now so this is false
        danswerbot_flow=False,
        # Maybe we want this in the future to track if it was created from Slack
        slack_thread_id=None,
    )


def update_chat_session(
    db_session: Session,
    user_id: UUID | None,
@@ -382,28 +336,6 @@ def get_chat_message(
    return chat_message


def get_chat_session_by_message_id(
    db_session: Session,
    message_id: int,
) -> ChatSession:
    """
    Should only be used for Slack
    Get the chat session associated with a specific message ID
    Note: this ignores permission checks.
    """
    stmt = select(ChatMessage).where(ChatMessage.id == message_id)

    result = db_session.execute(stmt)
    chat_message = result.scalar_one_or_none()

    if chat_message is None:
        raise ValueError(
            f"Unable to find chat session associated with message ID: {message_id}"
        )

    return chat_message.chat_session


def get_chat_messages_by_sessions(
    chat_session_ids: list[UUID],
    user_id: UUID | None,
@@ -423,44 +355,6 @@ def get_chat_messages_by_sessions(
    return db_session.execute(stmt).scalars().all()


def add_chats_to_session_from_slack_thread(
    db_session: Session,
    slack_chat_session_id: UUID,
    new_chat_session_id: UUID,
) -> None:
    new_root_message = get_or_create_root_message(
        chat_session_id=new_chat_session_id,
        db_session=db_session,
    )

    for chat_message in get_chat_messages_by_sessions(
        chat_session_ids=[slack_chat_session_id],
        user_id=None,  # Ignore user permissions for this
        db_session=db_session,
        skip_permission_check=True,
    ):
        if chat_message.message_type == MessageType.SYSTEM:
            continue
        # Duplicate the message
        new_root_message = create_new_chat_message(
            db_session=db_session,
            chat_session_id=new_chat_session_id,
            parent_message=new_root_message,
            message=chat_message.message,
            files=chat_message.files,
            rephrased_query=chat_message.rephrased_query,
            error=chat_message.error,
            citations=chat_message.citations,
            reference_docs=chat_message.search_docs,
            tool_call=chat_message.tool_call,
            prompt_id=chat_message.prompt_id,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            alternate_assistant_id=chat_message.alternate_assistant_id,
            overridden_model=chat_message.overridden_model,
        )


def get_search_docs_for_chat_message(
    chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingMode
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -312,25 +311,3 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
    # If this changes, we need to update this function.
    cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
    db_session.commit()


def mark_ccpair_with_indexing_trigger(
    cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
    """indexing_mode sets a field which will be picked up by a background task
    to trigger indexing. Set to None to disable the trigger."""
    try:
        cc_pair = db_session.execute(
            select(ConnectorCredentialPair)
            .where(ConnectorCredentialPair.id == cc_pair_id)
            .with_for_update()
        ).scalar_one()

        if cc_pair is None:
            raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

        cc_pair.indexing_trigger = indexing_mode
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
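# Illustrative usage sketch (not part of the diff): given an open SQLAlchemy
# Session `db_session`, the helper above queues a full re-index for a
# hypothetical cc_pair id, and clears the trigger again by passing None.
mark_ccpair_with_indexing_trigger(42, IndexingMode.REINDEX, db_session)
mark_ccpair_with_indexing_trigger(42, None, db_session)  # disable the trigger
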
@@ -324,11 +324,8 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
    user_group_ids: list[int] | None = None,
    user_group_ids: list[int],
) -> None:
    if not user_group_ids:
        return

    for group_id in user_group_ids:
        db_session.add(
            UserGroup__ConnectorCredentialPair(
@@ -405,11 +402,12 @@ def add_credential_to_connector(
        db_session.flush()  # make sure the association has an id
        db_session.refresh(association)

        _relate_groups_to_cc_pair__no_commit(
            db_session=db_session,
            cc_pair_id=association.id,
            user_group_ids=groups,
        )
        if groups and access_type != AccessType.SYNC:
            _relate_groups_to_cc_pair__no_commit(
                db_session=db_session,
                cc_pair_id=association.id,
                user_group_ids=groups,
            )

        db_session.commit()
@@ -19,11 +19,6 @@ class IndexingStatus(str, PyEnum):
        return self in terminal_states


class IndexingMode(str, PyEnum):
    UPDATE = "update"
    REINDEX = "reindex"


# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
    NOT_STARTED = "not_started"
@@ -42,7 +42,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType, IndexingMode
from danswer.db.enums import AccessType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -126,7 +126,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):

    # if specified, controls the assistants that are shown to the user + their order
    # if not specified, all assistants are shown
    auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
    chosen_assistants: Mapped[list[int] | None] = mapped_column(
        postgresql.JSONB(), nullable=True, default=None
    )
@@ -439,10 +438,6 @@ class ConnectorCredentialPair(Base):

    total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)

    indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
        Enum(IndexingMode, native_enum=False), nullable=True
    )

    connector: Mapped["Connector"] = relationship(
        "Connector", back_populates="credentials"
    )
@@ -1485,7 +1480,6 @@ class ChannelConfig(TypedDict):
    # If None then no follow up
    # If empty list, follow up with no tags
    follow_up_tags: NotRequired[list[str]]
    show_continue_in_web_ui: NotRequired[bool]  # defaults to False


class SlackBotResponseType(str, PyEnum):
@@ -113,31 +113,6 @@ def fetch_persona_by_id(
    return persona


def get_best_persona_id_for_user(
    db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
    if persona_id is not None:
        stmt = select(Persona).where(Persona.id == persona_id).distinct()
        stmt = _add_user_filters(
            stmt=stmt,
            user=user,
            # We don't want to filter by editable here, we just want to see if the
            # persona is usable by the user
            get_editable=False,
        )
        persona = db_session.scalars(stmt).one_or_none()
        if persona:
            return persona.id

    # If the persona is not found, or the slack bot is using doc sets instead of personas,
    # we need to find the best persona for the user
    # This is the persona with the highest display priority that the user has access to
    stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
    stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
    persona = db_session.scalars(stmt).one_or_none()
    return persona.id if persona else None


def _get_persona_by_name(
    persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
@@ -185,7 +160,7 @@ def create_update_persona(
        "persona_id": persona_id,
        "user": user,
        "db_session": db_session,
        **create_persona_request.model_dump(exclude={"users", "groups"}),
        **create_persona_request.dict(exclude={"users", "groups"}),
    }

    persona = upsert_persona(**persona_data)
@@ -758,8 +733,6 @@ def get_prompt_by_name(
    if user and user.role != UserRole.ADMIN:
        stmt = stmt.where(Prompt.user_id == user.id)

    # Order by ID to ensure consistent result when multiple prompts exist
    stmt = stmt.order_by(Prompt.id).limit(1)
    result = db_session.execute(stmt).scalar_one_or_none()
    return result
@@ -143,25 +143,6 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
    return latest_settings


def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
    """Returns active search settings. The first entry will always be the current search
    settings. If there are new search settings that are being migrated to, those will be
    the second entry."""
    search_settings_list: list[SearchSettings] = []

    # Get the primary search settings
    primary_search_settings = get_current_search_settings(db_session)
    search_settings_list.append(primary_search_settings)

    # Check for secondary search settings
    secondary_search_settings = get_secondary_search_settings(db_session)
    if secondary_search_settings is not None:
        # If secondary settings exist, add them to the list
        search_settings_list.append(secondary_search_settings)

    return search_settings_list


def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
    query = select(SearchSettings).order_by(SearchSettings.id.desc())
    result = db_session.execute(query)
@@ -295,7 +295,7 @@ def pptx_to_text(file: IO[Any]) -> str:


def xlsx_to_text(file: IO[Any]) -> str:
    workbook = openpyxl.load_workbook(file, read_only=True)
    workbook = openpyxl.load_workbook(file)
    text_content = []
    for sheet in workbook.worksheets:
        sheet_string = "\n".join(
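# Illustrative sketch (not part of the diff): openpyxl's read_only=True streams
# worksheet rows lazily instead of loading the whole workbook into memory,
# which is the difference between the two load_workbook calls above. The
# filename here is hypothetical.
import openpyxl

workbook = openpyxl.load_workbook("report.xlsx", read_only=True)
for sheet in workbook.worksheets:
    for row in sheet.iter_rows(values_only=True):
        print(row)
workbook.close()  # read-only workbooks hold the file open until closed
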
@@ -59,12 +59,6 @@ class FileStore(ABC):
        Contents of the file and metadata dict
        """

    @abstractmethod
    def read_file_record(self, file_name: str) -> PGFileStore:
        """
        Read the file record by the name
        """

    @abstractmethod
    def delete_file(self, file_name: str) -> None:
        """
@@ -67,9 +67,9 @@ class CitationProcessor:
        if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
            self.curr_segment = self.curr_segment.replace("```", "```plaintext")

        citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]"  # [1], [[1]], etc.
        citation_pattern = r"\[(\d+)\]"
        citations_found = list(re.finditer(citation_pattern, self.curr_segment))
        possible_citation_pattern = r"(\[+\d*$)"  # [1, [, [[, [[2, etc.
        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(
            possible_citation_pattern, self.curr_segment
        )
@@ -77,15 +77,13 @@ class CitationProcessor:
        if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
            self.current_citations = []

        result = ""
        result = ""  # Initialize result here
        if citations_found and not in_code_block(self.llm_out):
            last_citation_end = 0
            length_to_add = 0
            while len(citations_found) > 0:
                citation = citations_found.pop(0)
                numerical_value = int(
                    next(group for group in citation.groups() if group is not None)
                )
                numerical_value = int(citation.group(1))

                if 1 <= numerical_value <= self.max_citation_num:
                    context_llm_doc = self.context_docs[numerical_value - 1]
@@ -133,6 +131,14 @@ class CitationProcessor:

                    link = context_llm_doc.link

                    # Replace the citation in the current segment
                    start, end = citation.span()
                    self.curr_segment = (
                        self.curr_segment[: start + length_to_add]
                        + f"[{target_citation_num}]"
                        + self.curr_segment[end + length_to_add :]
                    )

                    self.past_cite_count = len(self.llm_out)
                    self.current_citations.append(target_citation_num)

@@ -143,7 +149,6 @@ class CitationProcessor:
                        document_id=context_llm_doc.document_id,
                    )

                    start, end = citation.span()
                    if link:
                        prev_length = len(self.curr_segment)
                        self.curr_segment = (
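# Illustrative sketch (not part of the diff): the two-group citation pattern
# above matches both [n] and [[n]] styles, and next(...) picks whichever
# capture group actually matched.
import re

citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]"  # [1], [[1]], etc.
for citation in re.finditer(citation_pattern, "see [1] and [[2]]"):
    numerical_value = int(next(g for g in citation.groups() if g is not None))
    print(numerical_value)  # prints 1, then 2
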
@@ -26,9 +26,7 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue

from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import (
    DISABLE_LITELLM_STREAMING,
)
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
@@ -163,9 +161,7 @@ def _convert_delta_to_message_chunk(

    if role == "user":
        return HumanMessageChunk(content=content)
    # NOTE: if tool calls are present, then it's an assistant.
    # In Ollama, the role will be None for tool-calls
    elif role == "assistant" or tool_calls:
    elif role == "assistant":
        if tool_calls:
            tool_call = tool_calls[0]
            tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -240,7 +236,6 @@ class DefaultMultiLLM(LLM):
        custom_config: dict[str, str] | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict | None = LITELLM_EXTRA_BODY,
        model_kwargs: dict[str, Any] | None = None,
        long_term_logger: LongTermLogger | None = None,
    ):
        self._timeout = timeout
@@ -273,7 +268,7 @@ class DefaultMultiLLM(LLM):
            for k, v in custom_config.items():
                os.environ[k] = v

        model_kwargs = model_kwargs or {}
        model_kwargs: dict[str, Any] = {}
        if extra_headers:
            model_kwargs.update({"extra_headers": extra_headers})
        if extra_body:
@@ -1,8 +1,5 @@
from typing import Any

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
@@ -16,15 +13,6 @@ from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.long_term_log import LongTermLogger


def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
    """Ollama requires us to specify the max context window.

    For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
    TODO: allow model-specific values to be configured via the UI.
    """
    return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}


def get_main_llm_from_tuple(
    llms: tuple[LLM, LLM],
) -> LLM:
@@ -144,6 +132,5 @@ def get_llm(
        temperature=temperature,
        custom_config=custom_config,
        extra_headers=build_llm_extra_headers(additional_headers),
        model_kwargs=_build_extra_model_kwargs(provider),
        long_term_logger=long_term_logger,
    )
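# Illustrative sketch (not part of the diff): _build_extra_model_kwargs above
# only injects a num_ctx for ollama; every other provider gets no extra kwargs.
assert _build_extra_model_kwargs("ollama") == {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS}
assert _build_extra_model_kwargs("openai") == {}
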
@@ -1,4 +1,3 @@
import copy
import io
import json
from collections.abc import Callable
@@ -386,62 +385,6 @@ def test_llm(llm: LLM) -> str | None:
    return error_msg


def get_model_map() -> dict:
    starting_map = copy.deepcopy(cast(dict, litellm.model_cost))

    # NOTE: we could add additional models here in the future,
    # but for now there is no point. Ollama allows the user to
    # to specify their desired max context window, and it's
    # unlikely to be standard across users even for the same model
    # (it heavily depends on their hardware). For now, we'll just
    # rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
    # for model_name in [
    #     "llama3.2",
    #     "llama3.2:1b",
    #     "llama3.2:3b",
    #     "llama3.2:11b",
    #     "llama3.2:90b",
    # ]:
    #     starting_map[f"ollama/{model_name}"] = {
    #         "max_tokens": 128000,
    #         "max_input_tokens": 128000,
    #         "max_output_tokens": 128000,
    #     }

    return starting_map


def _strip_extra_provider_from_model_name(model_name: str) -> str:
    return model_name.split("/")[1] if "/" in model_name else model_name


def _strip_colon_from_model_name(model_name: str) -> str:
    return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def _find_model_obj(
    model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
    # Filter out None values and deduplicate model names
    filtered_model_names = [name for name in model_names if name]

    # First try all model names with provider prefix
    for model_name in filtered_model_names:
        model_obj = model_map.get(f"{provider}/{model_name}")
        if model_obj:
            logger.debug(f"Using model object for {provider}/{model_name}")
            return model_obj

    # Then try all model names without provider prefix
    for model_name in filtered_model_names:
        model_obj = model_map.get(model_name)
        if model_obj:
            logger.debug(f"Using model object for {model_name}")
            return model_obj

    return None


def get_llm_max_tokens(
    model_map: dict,
    model_name: str,
@@ -454,22 +397,22 @@ def get_llm_max_tokens(
        return GEN_AI_MAX_TOKENS

    try:
        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
            model_name
        )
        model_obj = _find_model_obj(
            model_map,
            model_provider,
            [
                model_name,
                # Remove leading extra provider. Usually for cases where user has a
                # customer model proxy which appends another prefix
                extra_provider_stripped_model_name,
                # remove :XXXX from the end, if present. Needed for ollama.
                _strip_colon_from_model_name(model_name),
                _strip_colon_from_model_name(extra_provider_stripped_model_name),
            ],
        )
        model_obj = model_map.get(f"{model_provider}/{model_name}")
        if model_obj:
            logger.debug(f"Using model object for {model_provider}/{model_name}")

        if not model_obj:
            model_obj = model_map.get(model_name)
            if model_obj:
                logger.debug(f"Using model object for {model_name}")

        if not model_obj:
            model_name_split = model_name.split("/")
            if len(model_name_split) > 1:
                model_obj = model_map.get(model_name_split[1])
                if model_obj:
                    logger.debug(f"Using model object for {model_name_split[1]}")

        if not model_obj:
            raise RuntimeError(
                f"No litellm entry found for {model_provider}/{model_name}"
@@ -545,7 +488,7 @@ def get_max_input_tokens(
    # `model_cost` dict is a named public interface:
    # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
    # model_map is litellm.model_cost
    litellm_model_map = get_model_map()
    litellm_model_map = litellm.model_cost

    input_toks = (
        get_llm_max_tokens(
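# Illustrative sketch (not part of the diff): the strip helpers above produce
# the fallback lookup keys that get fed to _find_model_obj. The example names
# here are hypothetical.
assert _strip_extra_provider_from_model_name("my_proxy/gpt-4o") == "gpt-4o"
assert _strip_colon_from_model_name("llama3.2:3b") == "llama3.2"  # ollama-style tag
assert _strip_colon_from_model_name("gpt-4o") == "gpt-4o"  # unchanged without a colon
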
@@ -26,7 +26,6 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import create_danswer_oauth_router
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -45,7 +44,6 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.db.engine import warm_up_connections
from danswer.server.api_key.api import router as api_key_router
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -282,7 +280,6 @@ def get_application() -> FastAPI:
        application, get_full_openai_assistants_api_router()
    )
    include_router_with_global_prefix_prepended(application, long_term_logs_router)
    include_router_with_global_prefix_prepended(application, api_key_router)

    if AUTH_TYPE == AuthType.DISABLED:
        # Server logs this during auth setup verification step
@@ -326,7 +323,7 @@ def get_application() -> FastAPI:
            oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
            include_router_with_global_prefix_prepended(
                application,
                create_danswer_oauth_router(
                fastapi_users.get_oauth_router(
                    oauth_client,
                    auth_backend,
                    USER_AUTH_SECRET,
@@ -1,4 +0,0 @@
class ModelServerRateLimitError(Exception):
    """
    Exception raised for rate limiting errors from the model server.
    """
@@ -6,9 +6,6 @@ from typing import Any

import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry

from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -19,9 +16,6 @@ from danswer.configs.model_configs import (
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.exceptions import (
    ModelServerRateLimitError,
)
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -105,43 +99,28 @@ class EmbeddingModel:
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
        def _make_request() -> Response:
        def _make_request() -> EmbedResponse:
            response = requests.post(
                self.embed_server_endpoint, json=embed_request.model_dump()
            )
            # signify that this is a rate limit error
            if response.status_code == 429:
                raise ModelServerRateLimitError(response.text)

            response.raise_for_status()
            return response

        final_make_request_func = _make_request

        # if the text type is a passage, add some default
        # retries + handling for rate limiting
        if embed_request.text_type == EmbedTextType.PASSAGE:
            final_make_request_func = retry(
                tries=3,
                delay=5,
                exceptions=(RequestException, ValueError, JSONDecodeError),
            )(final_make_request_func)
            # use 10 second delay as per Azure suggestion
            final_make_request_func = retry(
                tries=10, delay=10, exceptions=ModelServerRateLimitError
            )(final_make_request_func)

        try:
            response = final_make_request_func()
            return EmbedResponse(**response.json())
        except requests.HTTPError as e:
            try:
                error_detail = response.json().get("detail", str(e))
            except Exception:
                error_detail = response.text
            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e
                response.raise_for_status()
            except requests.HTTPError as e:
                try:
                    error_detail = response.json().get("detail", str(e))
                except Exception:
                    error_detail = response.text
                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
            except requests.RequestException as e:
                raise HTTPError(f"Request failed: {str(e)}") from e

            return EmbedResponse(**response.json())

        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
        if embed_request.text_type == EmbedTextType.PASSAGE:
            return retry(tries=3, delay=5)(_make_request)()
        else:
            return _make_request()

    def _batch_encode_texts(
        self,
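# Illustrative sketch (not part of the diff): stacking two retry() wrappers as
# above retries transient errors on one schedule and rate-limit errors on a
# slower one. Exception types here are stand-ins, and delays are zeroed so the
# sketch runs instantly.
from retry import retry

attempts = {"n": 0}

def sometimes_rate_limited() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("429")  # stand-in for ModelServerRateLimitError
    return "ok"

wrapped = retry(tries=3, delay=0, exceptions=(ValueError,))(sometimes_rate_limited)
wrapped = retry(tries=10, delay=0, exceptions=(RuntimeError,))(wrapped)
print(wrapped())  # "ok" after two retried rate-limit errors
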
@@ -131,7 +131,7 @@ def _try_initialize_tokenizer(
        return tokenizer
    except Exception as hf_error:
        logger.warning(
            f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
            f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
        )

    # If both initializations fail, return None
@@ -47,7 +47,6 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.one_shot_answer.qa_utils import slackify_message_thread
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -195,22 +194,13 @@ def stream_answer_objects(
    )
    prompt = persona.prompts[0]

    user_message_str = query_msg.message
    # For this endpoint, we only save one user message to the chat session
    # However, for slackbot, we want to include the history of the entire thread
    if danswerbot_flow:
        # Right now, we only support bringing over citations and search docs
        # from the last message in the thread, not the entire thread
        # in the future, we may want to retrieve the entire thread
        user_message_str = slackify_message_thread(query_req.messages)

    # Create the first User query message
    new_user_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=root_message,
        prompt_id=query_req.prompt_id,
        message=user_message_str,
        token_count=len(llm_tokenizer.encode(user_message_str)),
        message=query_msg.message,
        token_count=len(llm_tokenizer.encode(query_msg.message)),
        message_type=MessageType.USER,
        db_session=db_session,
        commit=True,
@@ -51,31 +51,3 @@ def combine_message_thread(
        total_token_count += message_token_count

    return "\n\n".join(message_strs)


def slackify_message(message: ThreadMessage) -> str:
    if message.role != MessageType.USER:
        return message.message

    return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"


def slackify_message_thread(messages: list[ThreadMessage]) -> str:
    if not messages:
        return ""

    message_strs: list[str] = []
    for message in messages:
        if message.role == MessageType.USER:
            message_text = (
                f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
            )
        elif message.role == MessageType.ASSISTANT:
            message_text = f"DanswerBot said in Slack:\n{message.message}"
        else:
            message_text = (
                f"{message.role.value.upper()} said in Slack:\n{message.message}"
            )
        message_strs.append(message_text)

    return "\n\n".join(message_strs)
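# Illustrative sketch (not part of the diff): roughly what slackify_message_thread
# does, using stand-in types for ThreadMessage/MessageType. Senders and messages
# here are hypothetical.
from dataclasses import dataclass

@dataclass
class Msg:
    role: str  # "user" or "assistant"
    sender: str | None
    message: str

def slackify(messages: list[Msg]) -> str:
    parts = []
    for m in messages:
        who = (m.sender or "Unknown User") if m.role == "user" else "DanswerBot"
        parts.append(f"{who} said in Slack:\n{m.message}")
    return "\n\n".join(parts)

print(slackify([Msg("user", "alice", "hi"), Msg("assistant", None, "hello")]))
# alice said in Slack:
# hi
#
# DanswerBot said in Slack:
# hello
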
@@ -1,8 +1,5 @@
import time

import redis

from danswer.db.models import SearchSettings
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -34,44 +31,6 @@ class RedisConnector:
            self.tenant_id, self.id, search_settings_id, self.redis
        )

    def wait_for_indexing_termination(
        self,
        search_settings_list: list[SearchSettings],
        timeout: float = 15.0,
    ) -> bool:
        """
        Returns True if all indexing for the given redis connector is finished within the given timeout.
        Returns False if the timeout is exceeded

        This check does not guarantee that current indexings being terminated
        won't get restarted midflight
        """

        finished = False

        start = time.monotonic()

        while True:
            still_indexing = False
            for search_settings in search_settings_list:
                redis_connector_index = self.new_index(search_settings.id)
                if redis_connector_index.fenced:
                    still_indexing = True
                    break

            if not still_indexing:
                finished = True
                break

            now = time.monotonic()
            if now - start > timeout:
                break

            time.sleep(1)
            continue

        return finished

    @staticmethod
    def get_id_from_fence_key(key: str) -> str | None:
        """
@@ -14,9 +14,8 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues


class RedisConnectorPermissionSyncPayload(BaseModel):
class RedisConnectorPermissionSyncData(BaseModel):
    started: datetime | None
    celery_task_id: str | None


class RedisConnectorPermissionSync:
@@ -79,14 +78,14 @@ class RedisConnectorPermissionSync:
        return False

    @property
    def payload(self) -> RedisConnectorPermissionSyncPayload | None:
    def payload(self) -> RedisConnectorPermissionSyncData | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorPermissionSyncPayload.model_validate_json(
        payload = RedisConnectorPermissionSyncData.model_validate_json(
            cast(str, fence_str)
        )

@@ -94,7 +93,7 @@ class RedisConnectorPermissionSync:

    def set_fence(
        self,
        payload: RedisConnectorPermissionSyncPayload | None,
        payload: RedisConnectorPermissionSyncData | None,
    ) -> None:
        if not payload:
            self.redis.delete(self.fence_key)
@@ -163,12 +162,6 @@ class RedisConnectorPermissionSync:

        return len(async_results)

    def reset(self) -> None:
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
        self.redis.delete(self.fence_key)

    @staticmethod
    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
        taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
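# Illustrative sketch (not part of the diff): the fence round-trip used above,
# serializing a pydantic payload into the Redis fence key and validating it
# back out. The model here is a stand-in for RedisConnectorPermissionSyncData.
from datetime import datetime, timezone
from pydantic import BaseModel

class SyncPayload(BaseModel):
    started: datetime | None
    celery_task_id: str | None

raw = SyncPayload(
    started=datetime.now(timezone.utc), celery_task_id="abc123"
).model_dump_json()
restored = SyncPayload.model_validate_json(raw)
print(restored.celery_task_id)  # "abc123"
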
@@ -1,18 +1,11 @@
from datetime import datetime
from typing import cast

import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session


class RedisConnectorExternalGroupSyncPayload(BaseModel):
    started: datetime | None
    celery_task_id: str | None


class RedisConnectorExternalGroupSync:
    """Manages interactions with redis for external group syncing tasks. Should only be accessed
    through RedisConnector."""
@@ -75,29 +68,12 @@ class RedisConnectorExternalGroupSync:

        return False

    @property
    def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
            cast(str, fence_str)
        )

        return payload

    def set_fence(
        self,
        payload: RedisConnectorExternalGroupSyncPayload | None,
    ) -> None:
        if not payload:
    def set_fence(self, value: bool) -> None:
        if not value:
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())
        self.redis.set(self.fence_key, 0)

    @property
    def generator_complete(self) -> int | None:
@@ -29,8 +29,6 @@ class RedisConnectorIndex:

    GENERATOR_LOCK_PREFIX = "da_lock:indexing"

    TERMINATE_PREFIX = PREFIX + "_terminate"  # connectorindexing_terminate

    def __init__(
        self,
        tenant_id: str | None,
@@ -53,7 +51,6 @@ class RedisConnectorIndex:
        self.generator_lock_key = (
            f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
        )
        self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"

    @classmethod
    def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -95,18 +92,6 @@ class RedisConnectorIndex:

        self.redis.set(self.fence_key, payload.model_dump_json())

    def terminating(self, celery_task_id: str) -> bool:
        if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
            return True

        return False

    def set_terminate(self, celery_task_id: str) -> None:
        """This sets a signal. It does not block!"""
        # We shouldn't need very long to terminate the spawned task.
        # 10 minute TTL is good.
        self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)

    def set_generator_complete(self, payload: int | None) -> None:
        if not payload:
            self.redis.delete(self.generator_complete_key)
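# Illustrative sketch (not part of the diff): the terminate signal above is
# just a key with a 10-minute TTL; watchers only check for its existence.
# Assumes a locally reachable Redis; the key name is hypothetical.
import redis

r = redis.Redis()
r.set("connectorindexing_terminate_1/2_some-task-id", 0, ex=600)
print(bool(r.exists("connectorindexing_terminate_1/2_some-task-id")))  # True until expiry
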
@@ -5,7 +5,7 @@ personas:
  # this is for DanswerBot to use when tagged in a non-configured channel
  # Careful setting specific IDs, this won't autoincrement the next ID value for postgres
  - id: 0
    name: "Search"
    name: "Knowledge"
    description: >
      Assistant with access to documents from your Connected Sources.
  # Default Prompt objects attached to the persona, see prompts.yaml
@@ -6,7 +6,6 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

@@ -38,9 +37,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
@@ -161,19 +158,7 @@ def update_cc_pair_status(
    status_update_request: CCStatusUpdateRequest,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """This method may wait up to 30 seconds if pausing the connector due to the need to
    terminate tasks in progress. Tasks are not guaranteed to terminate within the
    timeout.

    Returns HTTPStatus.OK if everything finished.
    Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
    did not finish within the timeout.
    """
    WAIT_TIMEOUT = 15.0
    still_terminating = False

) -> None:
    cc_pair = get_connector_credential_pair_from_id(
        cc_pair_id=cc_pair_id,
        db_session=db_session,
@@ -188,76 +173,10 @@ def update_cc_pair_status(
    )

    if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
        search_settings_list: list[SearchSettings] = get_active_search_settings(
            db_session
        )

        cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)

        cancel_indexing_attempts_past_model(db_session)

        redis_connector = RedisConnector(tenant_id, cc_pair_id)

        try:
            redis_connector.stop.set_fence(True)
            while True:
                logger.debug(
                    f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
                )
                wait_succeeded = redis_connector.wait_for_indexing_termination(
                    search_settings_list, WAIT_TIMEOUT
                )
                if wait_succeeded:
                    logger.debug(
                        f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
                    )
                    break

                logger.debug(
                    "Wait for indexing soft termination timed out. "
                    f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
                )

                for search_settings in search_settings_list:
                    redis_connector_index = redis_connector.new_index(
                        search_settings.id
                    )
                    if not redis_connector_index.fenced:
                        continue

                    index_payload = redis_connector_index.payload
                    if not index_payload:
                        continue

                    if not index_payload.celery_task_id:
                        continue

                    # Revoke the task to prevent it from running
                    primary_app.control.revoke(index_payload.celery_task_id)

                    # If it is running, then signaling for termination will get the
                    # watchdog thread to kill the spawned task
                    redis_connector_index.set_terminate(index_payload.celery_task_id)

                logger.debug(
                    f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
                )
                wait_succeeded = redis_connector.wait_for_indexing_termination(
                    search_settings_list, WAIT_TIMEOUT
                )
                if wait_succeeded:
                    logger.debug(
                        f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
                    )
                    break

                logger.debug(
                    f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
                )
                still_terminating = True
                break
        finally:
            redis_connector.stop.set_fence(False)

    update_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
@@ -266,18 +185,6 @@ def update_cc_pair_status(

    db_session.commit()

    if still_terminating:
        return JSONResponse(
            status_code=HTTPStatus.ACCEPTED,
            content={
                "message": "Request accepted, background task termination still in progress"
            },
        )

    return JSONResponse(
        status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
    )


@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@@ -360,9 +267,9 @@ def prune_cc_pair(
    )

    logger.info(
        f"Pruning cc_pair: cc_pair={cc_pair_id} "
        f"connector={cc_pair.connector_id} "
        f"credential={cc_pair.credential_id} "
        f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
        f"connector_id={cc_pair.connector_id} "
        f"credential_id={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    tasks_created = try_creating_prune_generator_task(
@@ -17,9 +17,9 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.connectors.google_utils.google_auth import (
@@ -59,7 +59,6 @@ from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import fetch_connectors
from danswer.db.connector import get_connector_credential_ids
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector import update_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
@@ -75,7 +74,6 @@ from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import IndexingMode
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_latest_index_attempts
@@ -88,6 +86,7 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -793,10 +792,12 @@ def connector_run_once(
    _: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
) -> StatusResponse[list[int]]:
    """Used to trigger indexing on a set of cc_pairs associated with a
    single connector."""

    r = get_redis_client(tenant_id=tenant_id)

    connector_id = run_info.connector_id
    specified_credential_ids = run_info.credential_ids

@@ -842,41 +843,54 @@ def connector_run_once(
        )
    ]

    search_settings = get_current_search_settings(db_session)

    connector_credential_pairs = [
        get_connector_credential_pair(connector_id, credential_id, db_session)
        for credential_id in credential_ids
        if credential_id not in skipped_credentials
    ]

    num_triggers = 0
    index_attempt_ids = []
    for cc_pair in connector_credential_pairs:
        if cc_pair is not None:
            indexing_mode = IndexingMode.UPDATE
            if run_info.from_beginning:
                indexing_mode = IndexingMode.REINDEX

            mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
            num_triggers += 1

            logger.info(
                f"connector_run_once - marking cc_pair with indexing trigger: "
                f"connector={run_info.connector_id} "
                f"cc_pair={cc_pair.id} "
                f"indexing_trigger={indexing_mode}"
            attempt_id = try_creating_indexing_task(
                primary_app,
                cc_pair,
                search_settings,
                run_info.from_beginning,
                db_session,
                r,
                tenant_id,
            )
            if attempt_id:
                logger.info(
                    f"connector_run_once - try_creating_indexing_task succeeded: "
                    f"connector={run_info.connector_id} "
                    f"cc_pair={cc_pair.id} "
                    f"attempt={attempt_id} "
                )
                index_attempt_ids.append(attempt_id)
            else:
                logger.info(
                    f"connector_run_once - try_creating_indexing_task failed: "
                    f"connector={run_info.connector_id} "
                    f"cc_pair={cc_pair.id}"
                )

    # run the beat task to pick up the triggers immediately
    primary_app.send_task(
        "check_for_indexing",
        priority=DanswerCeleryPriority.HIGH,
        kwargs={"tenant_id": tenant_id},
    )
    if not index_attempt_ids:
        msg = "No new indexing attempts created, indexing jobs are queued or running."
        logger.info(msg)
        raise HTTPException(
            status_code=400,
            detail=msg,
        )

    msg = f"Marked {num_triggers} index attempts with indexing triggers."
    msg = f"Successfully created {len(index_attempt_ids)} index attempts. {index_attempt_ids}"
    return StatusResponse(
        success=True,
        message=msg,
        data=num_triggers,
        data=index_attempt_ids,
    )
@@ -45,7 +45,6 @@ class UserPreferences(BaseModel):
    visible_assistants: list[int] = []
    recent_assistants: list[int] | None = None
    default_model: str | None = None
    auto_scroll: bool | None = None


class UserInfo(BaseModel):
@@ -80,7 +79,6 @@ class UserInfo(BaseModel):
            role=user.role,
            preferences=(
                UserPreferences(
                    auto_scroll=user.auto_scroll,
                    chosen_assistants=user.chosen_assistants,
                    default_model=user.default_model,
                    hidden_assistants=user.hidden_assistants,
@@ -130,10 +128,6 @@ class HiddenUpdateRequest(BaseModel):
    hidden: bool


class AutoScrollRequest(BaseModel):
    auto_scroll: bool | None


class SlackBotCreationRequest(BaseModel):
    name: str
    enabled: bool
@@ -162,7 +156,6 @@ class SlackChannelConfigCreationRequest(BaseModel):
    channel_name: str
    respond_tag_only: bool = False
    respond_to_bots: bool = False
    show_continue_in_web_ui: bool = False
    enable_auto_filters: bool = False
    # If no team members, assume respond in the channel to everyone
    respond_member_group_list: list[str] = Field(default_factory=list)
@@ -80,10 +80,6 @@ def _form_channel_config(
    if follow_up_tags is not None:
        channel_config["follow_up_tags"] = follow_up_tags

    channel_config[
        "show_continue_in_web_ui"
    ] = slack_channel_config_creation_request.show_continue_in_web_ui

    channel_config[
        "respond_to_bots"
    ] = slack_channel_config_creation_request.respond_to_bots
@@ -34,6 +34,7 @@ from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.api_key import is_api_key_email_address
@@ -51,7 +52,6 @@ from danswer.db.users import list_users
from danswer.db.users import validate_user_role_update
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import AutoScrollRequest
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
@@ -63,7 +63,6 @@ from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.configs.app_configs import SUPER_USERS
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()
@@ -498,6 +497,7 @@ def verify_user_logged_in(
            return fetch_no_auth_user(store)

        raise BasicAuthenticationError(detail="User Not Authenticated")

    if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
        raise BasicAuthenticationError(
            detail="Access denied. User's OIDC token has expired.",
@@ -581,30 +581,6 @@ def update_user_recent_assistants(
    db_session.commit()


@router.patch("/auto-scroll")
def update_user_auto_scroll(
    request: AutoScrollRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    if user is None:
        if AUTH_TYPE == AuthType.DISABLED:
            store = get_kv_store()
            no_auth_user = fetch_no_auth_user(store)
            no_auth_user.preferences.auto_scroll = request.auto_scroll
            set_no_auth_user_preferences(store, no_auth_user.preferences)
            return
        else:
            raise RuntimeError("This should never happen")

    db_session.execute(
        update(User)
        .where(User.id == user.id)  # type: ignore
        .values(auto_scroll=request.auto_scroll)
    )
    db_session.commit()


@router.patch("/user/default-model")
def update_user_default_model(
    request: ChosenDefaultModelRequest,
@@ -27,11 +27,9 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
from danswer.db.chat import add_chats_to_session_from_slack_thread
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import delete_chat_session
|
||||
from danswer.db.chat import duplicate_chat_session_for_user_from_slack
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
@@ -534,38 +532,6 @@ def seed_chat(
|
||||
)
|
||||
|
||||
|
||||
class SeedChatFromSlackRequest(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
class SeedChatFromSlackResponse(BaseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
@router.post("/seed-chat-session-from-slack")
|
||||
def seed_chat_from_slack(
|
||||
chat_seed_request: SeedChatFromSlackRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeedChatFromSlackResponse:
|
||||
slack_chat_session_id = chat_seed_request.chat_session_id
|
||||
new_chat_session = duplicate_chat_session_for_user_from_slack(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
chat_session_id=slack_chat_session_id,
|
||||
)
|
||||
|
||||
add_chats_to_session_from_slack_thread(
|
||||
db_session=db_session,
|
||||
slack_chat_session_id=slack_chat_session_id,
|
||||
new_chat_session_id=new_chat_session.id,
|
||||
)
|
||||
|
||||
return SeedChatFromSlackResponse(
|
||||
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
|
||||
)
|
||||
|
||||
|
||||
"""File upload"""
|
||||
|
||||
|
||||
@@ -707,18 +673,14 @@ def upload_files_for_chat(
|
||||
}
|
||||
|
||||
|
||||
@router.get("/file/{file_id:path}")
|
||||
@router.get("/file/{file_id}")
|
||||
def fetch_chat_file(
|
||||
file_id: str,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> Response:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(file_id)
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
return StreamingResponse(file_io, media_type=media_type)
|
||||
# NOTE: specifying "image/jpeg" here, but it still works for pngs
|
||||
# TODO: do this properly
|
||||
return Response(content=file_io.read(), media_type="image/jpeg")
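The route change above swaps `{file_id}` for `{file_id:path}`, which uses Starlette's `path` converter so a file ID containing slashes still matches a single route. A minimal sketch of the converter's behavior (names are hypothetical, not from the diff):

```python
# Minimal sketch of the ":path" converter behavior.
from fastapi import FastAPI

app = FastAPI()

@app.get("/file/{file_id:path}")
def fetch_file(file_id: str) -> dict[str, str]:
    # With ":path", file_id may itself contain slashes, e.g.
    # GET /file/chat_images/abc.png -> file_id == "chat_images/abc.png"
    return {"file_id": file_id}
```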

@@ -79,7 +79,6 @@ class CreateChatMessageRequest(ChunkContext):
    message: str
    # Files that we should attach to this message
    file_descriptors: list[FileDescriptor]

    # If no prompt provided, uses the largest prompt of the chat session
    # but really this should be explicitly specified, only in the simplified APIs is this inferred
    # Use prompt_id 0 to use the system default prompt which is Answer-Question

@@ -2,6 +2,7 @@ from typing import cast

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

@@ -37,6 +38,10 @@ basic_router = APIRouter(prefix="/settings")
def put_settings(
    settings: Settings, _: User | None = Depends(current_admin_user)
) -> None:
    try:
        settings.check_validity()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    store_settings(settings)


@@ -41,10 +41,33 @@ class Notification(BaseModel):
class Settings(BaseModel):
    """General settings"""

    chat_page_enabled: bool = True
    search_page_enabled: bool = True
    default_page: PageType = PageType.SEARCH
    maximum_chat_retention_days: int | None = None
    gpu_enabled: bool | None = None
    product_gating: GatingType = GatingType.NONE

    def check_validity(self) -> None:
        chat_page_enabled = self.chat_page_enabled
        search_page_enabled = self.search_page_enabled
        default_page = self.default_page

        if chat_page_enabled is False and search_page_enabled is False:
            raise ValueError(
                "One of `search_page_enabled` and `chat_page_enabled` must be True."
            )

        if default_page == PageType.CHAT and chat_page_enabled is False:
            raise ValueError(
                "The default page cannot be 'chat' if the chat page is disabled."
            )

        if default_page == PageType.SEARCH and search_page_enabled is False:
            raise ValueError(
                "The default page cannot be 'search' if the search page is disabled."
            )


class UserSettings(Settings):
    notifications: list[Notification]
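`check_validity` centralizes the cross-field constraints (at least one page enabled, and the default page must point at an enabled page), and the PUT handler above converts the resulting `ValueError` into a 400. A minimal sketch of the failure mode; the import path is an assumption based on the module shown here:

```python
# Sketch only: import path assumed from the settings module above.
from danswer.server.settings.models import PageType, Settings

settings = Settings(chat_page_enabled=False, default_page=PageType.CHAT)
try:
    settings.check_validity()
except ValueError as e:
    # "The default page cannot be 'chat' if the chat page is disabled."
    print(e)
```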

@@ -1,72 +1,23 @@
from functools import lru_cache

import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL
from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from ee.danswer.configs.app_configs import SUPER_USERS
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie


logger = setup_logger()


@lru_cache()
def get_public_key() -> str | None:
    if JWT_PUBLIC_KEY_URL is None:
        logger.error("JWT_PUBLIC_KEY_URL is not set")
        return None

    response = requests.get(JWT_PUBLIC_KEY_URL)
    response.raise_for_status()
    return response.text


async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
    try:
        public_key_pem = get_public_key()
        if public_key_pem is None:
            logger.error("Failed to retrieve public key")
            return None

        payload = jwt_decode(
            token,
            public_key_pem,
            algorithms=["RS256"],
            audience=None,
        )
        email = payload.get("email")
        if email:
            result = await async_db_session.execute(
                select(User).where(func.lower(User.email) == func.lower(email))
            )
            return result.scalars().first()
    except InvalidTokenError:
        logger.error("Invalid JWT token")
        get_public_key.cache_clear()
    except PyJWTError as e:
        logger.error(f"JWT decoding error: {str(e)}")
        get_public_key.cache_clear()
    return None


def verify_auth_setting() -> None:
    # All the Auth flows are valid for EE version
    logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -87,13 +38,6 @@ async def optional_user_(
    )
    user = saml_account.user if saml_account else None

    # If user is still None, check for JWT in Authorization header
    if user is None and JWT_PUBLIC_KEY_URL is not None:
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header[len("Bearer ") :].strip()
            user = await verify_jwt_token(token, async_db_session)

    return user
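`verify_jwt_token` expects an RS256 token whose payload carries an `email` claim, validated against the PEM public key fetched from `JWT_PUBLIC_KEY_URL` and passed along as `Bearer <token>` in the Authorization header. A self-contained sketch of that round trip using PyJWT and cryptography (key material generated locally purely for illustration):

```python
# Minimal sketch of the token round trip verify_jwt_token performs.
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key_pem = private_key.public_key().public_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode()

# The Authorization header would carry this as "Bearer <token>".
token = jwt.encode({"email": "user@example.com"}, private_key, algorithm="RS256")

payload = jwt.decode(token, public_key_pem, algorithms=["RS256"], audience=None)
assert payload["email"] == "user@example.com"
```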

@@ -1,4 +1,3 @@
import json
import os

# Applicable for OIDC Auth
@@ -20,11 +19,3 @@ STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")

# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)


# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

@@ -11,7 +11,6 @@ from sqlalchemy import update
from sqlalchemy.orm import Session

from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential__UserGroup
@@ -299,11 +298,6 @@ def fetch_user_groups_for_documents(
    db_session: Session,
    document_ids: list[str],
) -> Sequence[tuple[str, list[str]]]:
    """
    Fetches all user groups that have access to the given documents.

    NOTE: this doesn't include groups if the cc_pair is access type SYNC
    """
    stmt = (
        select(Document.id, func.array_agg(UserGroup.name))
        .join(
@@ -312,11 +306,7 @@ def fetch_user_groups_for_documents(
        )
        .join(
            ConnectorCredentialPair,
            and_(
                ConnectorCredentialPair.id
                == UserGroup__ConnectorCredentialPair.cc_pair_id,
                ConnectorCredentialPair.access_type != AccessType.SYNC,
            ),
            ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
        )
        .join(
            DocumentByConnectorCredentialPair,

@@ -16,7 +16,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()

_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
_REQUEST_PAGINATION_LIMIT = 5000
_REQUEST_PAGINATION_LIMIT = 100


def _get_server_space_permissions(
@@ -97,7 +97,6 @@ def _get_space_permissions(
    confluence_client: OnyxConfluence,
    is_cloud: bool,
) -> dict[str, ExternalAccess]:
    logger.debug("Getting space permissions")
    # Gets all the spaces in the Confluence instance
    all_space_keys = []
    start = 0
@@ -114,7 +113,6 @@ def _get_space_permissions(
        start += len(spaces_batch.get("results", []))

    # Gets the permissions for each space
    logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
    space_permissions_by_space_key: dict[str, ExternalAccess] = {}
    for space_key in all_space_keys:
        if is_cloud:
@@ -244,7 +242,6 @@ def _fetch_all_page_restrictions_for_space(

        logger.warning(f"No permissions found for document {slim_doc.id}")

    logger.debug("Finished fetching all page restrictions for space")
    return document_restrictions


@@ -257,28 +254,27 @@ def confluence_doc_sync(
    it in postgres so that when it gets created later, the permissions are
    already populated
    """
    logger.debug("Starting confluence doc sync")
    confluence_connector = ConfluenceConnector(
        **cc_pair.connector.connector_specific_config
    )
    confluence_connector.load_credentials(cc_pair.credential.credential_json)
    if confluence_connector.confluence_client is None:
        raise ValueError("Failed to load credentials")
    confluence_client = confluence_connector.confluence_client

    is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)

    space_permissions_by_space_key = _get_space_permissions(
        confluence_client=confluence_connector.confluence_client,
        confluence_client=confluence_client,
        is_cloud=is_cloud,
    )

    slim_docs = []
    logger.debug("Fetching all slim documents from confluence")
    for doc_batch in confluence_connector.retrieve_all_slim_documents():
        logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
        slim_docs.extend(doc_batch)

    logger.debug("Fetching all page restrictions for space")
    return _fetch_all_page_restrictions_for_space(
        confluence_client=confluence_connector.confluence_client,
        confluence_client=confluence_client,
        slim_docs=slim_docs,
        space_permissions_by_space_key=space_permissions_by_space_key,
    )

@@ -14,10 +14,7 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
    group_member_emails: dict[str, set[str]] = {}
    for user_result in confluence_client.paginated_cql_user_retrieval():
        user = user_result.get("user", {})
        if not user:
            logger.warning(f"user result missing user field: {user_result}")
            continue
        user = user_result["user"]
        email = user.get("email")
        if not email:
            # This field is only present in Confluence Server

@@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {

# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
    # Polling is not supported so we fetch all group permissions every 5 minutes
    DocumentSource.GOOGLE_DRIVE: 5 * 60,
    DocumentSource.CONFLUENCE: 5 * 60,
    # Polling is not supported so we fetch all group permissions every 60 seconds
    DocumentSource.GOOGLE_DRIVE: 60,
    DocumentSource.CONFLUENCE: 60,
}
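The hunk above tightens the Confluence and Google Drive group-permission sync cadence from every 5 minutes to every 60 seconds, while sources absent from the map sync on every beat. A hypothetical helper (not part of the diff) showing how a scheduler might consult the period map:

```python
import time

# Hypothetical helper: returns True when a source's group sync is due,
# defaulting to "sync on every beat" for sources not listed in the map.
def group_sync_is_due(source, last_sync: float | None) -> bool:
    period = EXTERNAL_GROUP_SYNC_PERIODS.get(source)
    if period is None or last_sync is None:
        return True
    return time.monotonic() - last_sync >= period
```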

@@ -13,6 +13,7 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.main import get_application as get_application_base
from danswer.main import include_router_with_global_prefix_prepended
from danswer.server.api_key.api import router as api_key_router
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
@@ -115,6 +116,8 @@ def get_application() -> FastAPI:
        # Analytics endpoints
        include_router_with_global_prefix_prepended(application, analytics_router)
        include_router_with_global_prefix_prepended(application, query_history_router)
        # Api key management
        include_router_with_global_prefix_prepended(application, api_key_router)
        # EE only backend APIs
        include_router_with_global_prefix_prepended(application, query_router)
        include_router_with_global_prefix_prepended(application, chat_router)

@@ -113,6 +113,10 @@ async def refresh_access_token(
def put_settings(
    settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
) -> None:
    try:
        settings.check_validity()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    store_settings(settings)


@@ -157,6 +157,7 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
def _seed_settings(settings: Settings) -> None:
    logger.notice("Seeding Settings")
    try:
        settings.check_validity()
        store_base_settings(settings)
        logger.notice("Successfully seeded Settings")
    except ValueError as e:

@@ -11,7 +11,6 @@ from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account  # type: ignore
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
@@ -206,22 +205,28 @@ class CloudEmbedding:
        model_name: str | None = None,
        deployment_name: str | None = None,
    ) -> list[Embedding]:
        if self.provider == EmbeddingProvider.OPENAI:
            return self._embed_openai(texts, model_name)
        elif self.provider == EmbeddingProvider.AZURE:
            return self._embed_azure(texts, f"azure/{deployment_name}")
        elif self.provider == EmbeddingProvider.LITELLM:
            return self._embed_litellm_proxy(texts, model_name)
        try:
            if self.provider == EmbeddingProvider.OPENAI:
                return self._embed_openai(texts, model_name)
            elif self.provider == EmbeddingProvider.AZURE:
                return self._embed_azure(texts, f"azure/{deployment_name}")
            elif self.provider == EmbeddingProvider.LITELLM:
                return self._embed_litellm_proxy(texts, model_name)

            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
            if self.provider == EmbeddingProvider.COHERE:
                return self._embed_cohere(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.VOYAGE:
                return self._embed_voyage(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.GOOGLE:
                return self._embed_vertex(texts, model_name, embedding_type)
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
        if self.provider == EmbeddingProvider.COHERE:
            return self._embed_cohere(texts, model_name, embedding_type)
        elif self.provider == EmbeddingProvider.VOYAGE:
            return self._embed_voyage(texts, model_name, embedding_type)
        elif self.provider == EmbeddingProvider.GOOGLE:
            return self._embed_vertex(texts, model_name, embedding_type)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail=f"Error embedding text with {self.provider}: {str(e)}",
            )

    @staticmethod
    def create(
@@ -425,11 +430,6 @@ async def process_embed_request(
            prefix=prefix,
        )
        return EmbedResponse(embeddings=embeddings)
    except RateLimitError as e:
        raise HTTPException(
            status_code=429,
            detail=str(e),
        )
    except Exception as e:
        exception_detail = f"Error during embedding process:\n{str(e)}"
        logger.exception(exception_detail)
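With the changes above, provider rate limits surface from the model server as HTTP 429 while other embedding failures become 500s. A hedged client-side sketch that backs off on 429; the endpoint URL and payload shape here are assumptions, only the status-code contract comes from the handler above:

```python
import time

import requests

# Sketch only: URL and JSON shape are assumptions; 429 = rate limited,
# 500 = other embedding error, matching the server-side handlers above.
def embed_with_backoff(texts: list[str], retries: int = 3) -> list[list[float]]:
    for attempt in range(retries):
        resp = requests.post("http://localhost:9000/embed", json={"texts": texts})
        if resp.status_code == 429:
            time.sleep(2**attempt)  # exponential backoff, then retry
            continue
        resp.raise_for_status()  # surfaces 500s and other errors
        return resp.json()["embeddings"]
    raise RuntimeError("embedding request kept hitting rate limits")
```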

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.53.1
litellm==1.50.2
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.55.3
openai==1.52.2
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5

@@ -163,92 +163,47 @@ SUPPORTED_EMBEDDING_MODELS = [
        dim=1024,
        index_name="danswer_chunk_cohere_embed_english_v3_0",
    ),
    SupportedEmbeddingModel(
        name="cohere/embed-english-v3.0",
        dim=1024,
        index_name="danswer_chunk_embed_english_v3_0",
    ),
    SupportedEmbeddingModel(
        name="cohere/embed-english-light-v3.0",
        dim=384,
        index_name="danswer_chunk_cohere_embed_english_light_v3_0",
    ),
    SupportedEmbeddingModel(
        name="cohere/embed-english-light-v3.0",
        dim=384,
        index_name="danswer_chunk_embed_english_light_v3_0",
    ),
    SupportedEmbeddingModel(
        name="openai/text-embedding-3-large",
        dim=3072,
        index_name="danswer_chunk_openai_text_embedding_3_large",
    ),
    SupportedEmbeddingModel(
        name="openai/text-embedding-3-large",
        dim=3072,
        index_name="danswer_chunk_text_embedding_3_large",
    ),
    SupportedEmbeddingModel(
        name="openai/text-embedding-3-small",
        dim=1536,
        index_name="danswer_chunk_openai_text_embedding_3_small",
    ),
    SupportedEmbeddingModel(
        name="openai/text-embedding-3-small",
        dim=1536,
        index_name="danswer_chunk_text_embedding_3_small",
    ),
    SupportedEmbeddingModel(
        name="google/text-embedding-004",
        dim=768,
        index_name="danswer_chunk_google_text_embedding_004",
    ),
    SupportedEmbeddingModel(
        name="google/text-embedding-004",
        dim=768,
        index_name="danswer_chunk_text_embedding_004",
    ),
    SupportedEmbeddingModel(
        name="google/textembedding-gecko@003",
        dim=768,
        index_name="danswer_chunk_google_textembedding_gecko_003",
    ),
    SupportedEmbeddingModel(
        name="google/textembedding-gecko@003",
        dim=768,
        index_name="danswer_chunk_textembedding_gecko_003",
    ),
    SupportedEmbeddingModel(
        name="voyage/voyage-large-2-instruct",
        dim=1024,
        index_name="danswer_chunk_voyage_large_2_instruct",
    ),
    SupportedEmbeddingModel(
        name="voyage/voyage-large-2-instruct",
        dim=1024,
        index_name="danswer_chunk_large_2_instruct",
    ),
    SupportedEmbeddingModel(
        name="voyage/voyage-light-2-instruct",
        dim=384,
        index_name="danswer_chunk_voyage_light_2_instruct",
    ),
    SupportedEmbeddingModel(
        name="voyage/voyage-light-2-instruct",
        dim=384,
        index_name="danswer_chunk_light_2_instruct",
    ),
    # Self-hosted models
    SupportedEmbeddingModel(
        name="nomic-ai/nomic-embed-text-v1",
        dim=768,
        index_name="danswer_chunk_nomic_ai_nomic_embed_text_v1",
    ),
    SupportedEmbeddingModel(
        name="nomic-ai/nomic-embed-text-v1",
        dim=768,
        index_name="danswer_chunk_nomic_embed_text_v1",
    ),
    SupportedEmbeddingModel(
        name="intfloat/e5-base-v2",
        dim=768,

@@ -1,88 +0,0 @@
import json
import os
import time
from pathlib import Path

import pytest

from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.slab.connector import SlabConnector


def load_test_data(file_name: str = "test_slab_data.json") -> dict[str, str]:
    current_dir = Path(__file__).parent
    with open(current_dir / file_name, "r") as f:
        return json.load(f)


@pytest.fixture
def slab_connector() -> SlabConnector:
    connector = SlabConnector(
        base_url="https://onyx-test.slab.com/",
    )
    connector.load_credentials(
        {
            "slab_bot_token": os.environ["SLAB_BOT_TOKEN"],
        }
    )
    return connector


@pytest.mark.xfail(
    reason=(
        "Need a test account with a slab subscription to run this test."
        "Trial only lasts 14 days."
    )
)
def test_slab_connector_basic(slab_connector: SlabConnector) -> None:
    all_docs: list[Document] = []
    target_test_doc_id = "jcp6cohu"
    target_test_doc: Document | None = None
    for doc_batch in slab_connector.poll_source(0, time.time()):
        for doc in doc_batch:
            all_docs.append(doc)
            if doc.id == target_test_doc_id:
                target_test_doc = doc

    assert len(all_docs) == 6
    assert target_test_doc is not None

    desired_test_data = load_test_data()
    assert (
        target_test_doc.semantic_identifier == desired_test_data["semantic_identifier"]
    )
    assert target_test_doc.source == DocumentSource.SLAB
    assert target_test_doc.metadata == {}
    assert target_test_doc.primary_owners is None
    assert target_test_doc.secondary_owners is None
    assert target_test_doc.title is None
    assert target_test_doc.from_ingestion_api is False
    assert target_test_doc.additional_info is None

    assert len(target_test_doc.sections) == 1
    section = target_test_doc.sections[0]
    # Need to replace the weird apostrophe with a normal one
    assert section.text.replace("\u2019", "'") == desired_test_data["section_text"]
    assert section.link == desired_test_data["link"]


@pytest.mark.xfail(
    reason=(
        "Need a test account with a slab subscription to run this test."
        "Trial only lasts 14 days."
    )
)
def test_slab_connector_slim(slab_connector: SlabConnector) -> None:
    # Get all doc IDs from the full connector
    all_full_doc_ids = set()
    for doc_batch in slab_connector.load_from_state():
        all_full_doc_ids.update([doc.id for doc in doc_batch])

    # Get all doc IDs from the slim connector
    all_slim_doc_ids = set()
    for slim_doc_batch in slab_connector.retrieve_all_slim_documents():
        all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])

    # The set of full doc IDs should always be a subset of the slim doc IDs
    assert all_full_doc_ids.issubset(all_slim_doc_ids)
@@ -1,5 +0,0 @@
{
    "section_text": "Learn about Posts\nWelcome\nThis is a post, where you can edit, share, and collaborate in real time with your team. We'd love to show you how it works!\nReading and editing\nClick the mode button to toggle between read and edit modes. You can only make changes to a post when editing.\nOrganize your posts\nWhen in edit mode, you can add topics to a post, which will keep it organized for the right 👀 to see.\nSmart mentions\nMentions are references to users, posts, topics and third party tools that show details on hover. Paste in a link for automatic conversion.\nLook back in time\nYou are ready to begin writing. You can always bring back this tour in the help menu.\nGreat job!\nYou are ready to begin writing. You can always bring back this tour in the help menu.\n\n",
    "link": "https://onyx-test.slab.com/posts/learn-about-posts-jcp6cohu",
    "semantic_identifier": "Learn about Posts"
}

@@ -7,7 +7,6 @@ from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbeddingProvider

VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
VALID_LONG_SAMPLE = ["hi " * 999]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
# seem to be true
TOO_LONG_SAMPLE = ["a"] * 2500
@@ -100,42 +99,3 @@ def local_nomic_embedding_model() -> EmbeddingModel:
def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
    _run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
    _run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)


@pytest.fixture
def azure_embedding_model() -> EmbeddingModel:
    return EmbeddingModel(
        server_host="localhost",
        server_port=9000,
        model_name="text-embedding-3-large",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        api_key=os.getenv("AZURE_API_KEY"),
        provider_type=EmbeddingProvider.AZURE,
        api_url=os.getenv("AZURE_API_URL"),
    )


# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
#     """NOTE: this test relies on a very low rate limit for the Azure API +
#     this test only being run once in a 1 minute window"""
#     # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
#     # limits assuming the limit is 1000 tokens per minute
#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
#     assert len(result) == 1
#     assert len(result[0]) == 1536

#     # this should fail
#     with pytest.raises(ModelServerRateLimitError):
#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
#         azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)

#     # this should succeed, since passage requests retry up to 10 times
#     start = time.time()
#     result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
#     assert len(result) == 1
#     assert len(result[0]) == 1536
#     assert time.time() - start > 30  # make sure we waited, even though we hit rate limits

@@ -240,85 +240,7 @@ class CCPairManager:
        result.raise_for_status()

    @staticmethod
    def wait_for_indexing_inactive(
        cc_pair: DATestCCPair,
        timeout: float = MAX_DELAY,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """wait for the number of docs to be indexed on the connector.
        This is used to test pausing a connector in the middle of indexing and
        terminating that indexing."""
        print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}")
        start = time.monotonic()
        while True:
            fetched_cc_pairs = CCPairManager.get_indexing_statuses(
                user_performing_action
            )
            for fetched_cc_pair in fetched_cc_pairs:
                if fetched_cc_pair.cc_pair_id != cc_pair.id:
                    continue

                if fetched_cc_pair.in_progress:
                    continue

                print(f"Indexing is inactive: cc_pair={cc_pair.id}")
                return

            elapsed = time.monotonic() - start
            if elapsed > timeout:
                raise TimeoutError(
                    f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s"
                )

            print(
                f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
            )
            time.sleep(5)

    @staticmethod
    def wait_for_indexing_in_progress(
        cc_pair: DATestCCPair,
        timeout: float = MAX_DELAY,
        num_docs: int = 16,
        user_performing_action: DATestUser | None = None,
    ) -> None:
        """wait for the number of docs to be indexed on the connector.
        This is used to test pausing a connector in the middle of indexing and
        terminating that indexing."""
        start = time.monotonic()
        while True:
            fetched_cc_pairs = CCPairManager.get_indexing_statuses(
                user_performing_action
            )
            for fetched_cc_pair in fetched_cc_pairs:
                if fetched_cc_pair.cc_pair_id != cc_pair.id:
                    continue

                if not fetched_cc_pair.in_progress:
                    continue

                if fetched_cc_pair.docs_indexed >= num_docs:
                    print(
                        "Indexed at least the requested number of docs: "
                        f"cc_pair={cc_pair.id} "
                        f"docs_indexed={fetched_cc_pair.docs_indexed} "
                        f"num_docs={num_docs}"
                    )
                    return

            elapsed = time.monotonic() - start
            if elapsed > timeout:
                raise TimeoutError(
                    f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
                )

            print(
                f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
            )
            time.sleep(5)

    @staticmethod
    def wait_for_indexing_completion(
    def wait_for_indexing(
        cc_pair: DATestCCPair,
        after: datetime,
        timeout: float = MAX_DELAY,

@@ -1,62 +0,0 @@
import mimetypes
from typing import cast
from typing import IO
from typing import List
from typing import Tuple

import requests

from danswer.file_store.models import FileDescriptor
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser


class FileManager:
    @staticmethod
    def upload_files(
        files: List[Tuple[str, IO]],
        user_performing_action: DATestUser | None = None,
    ) -> Tuple[List[FileDescriptor], str]:
        headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        headers.pop("Content-Type", None)

        files_param = []
        for filename, file_obj in files:
            mime_type, _ = mimetypes.guess_type(filename)
            if mime_type is None:
                mime_type = "application/octet-stream"
            files_param.append(("files", (filename, file_obj, mime_type)))

        response = requests.post(
            f"{API_SERVER_URL}/chat/file",
            files=files_param,
            headers=headers,
        )

        if not response.ok:
            return (
                cast(List[FileDescriptor], []),
                f"Failed to upload files - {response.json().get('detail', 'Unknown error')}",
            )

        response_json = response.json()
        return response_json.get("files", cast(List[FileDescriptor], [])), ""

    @staticmethod
    def fetch_uploaded_file(
        file_id: str,
        user_performing_action: DATestUser | None = None,
    ) -> bytes:
        response = requests.get(
            f"{API_SERVER_URL}/chat/file/{file_id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        return response.content

@@ -14,7 +14,6 @@ from tests.integration.common_utils.managers.document_search import (
)
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
@@ -78,7 +77,7 @@ def test_slack_permission_sync(
        access_type=AccessType.SYNC,
        user_performing_action=admin_user,
    )
    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
@@ -113,7 +112,7 @@ def test_slack_permission_sync(
    # Run indexing
    before = datetime.now(timezone.utc)
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
@@ -216,124 +215,3 @@ def test_slack_permission_sync(
    # Ensure test_user_1 can only see messages from the public channel
    assert public_message in danswer_doc_message_strings
    assert private_message not in danswer_doc_message_strings


def test_slack_group_permission_sync(
    reset: None,
    vespa_client: vespa_fixture,
    slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
    """
    This test ensures that permission sync overrides danswer group access.
    """
    public_channel, private_channel = slack_test_setup

    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(
        email="admin@onyx-test.com",
    )

    # Creating a non-admin user
    test_user_1: DATestUser = UserManager.create(
        email="test_user_1@onyx-test.com",
    )

    # Create a user group and add the non-admin user to it
    user_group = UserGroupManager.create(
        name="test_group",
        user_ids=[test_user_1.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group],
        user_performing_action=admin_user,
    )

    slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
    email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
    admin_user_id = email_id_map[admin_user.email]

    LLMProviderManager.create(user_performing_action=admin_user)

    # Add only admin to the private channel
    SlackManager.set_channel_members(
        slack_client=slack_client,
        admin_user_id=admin_user_id,
        channel=private_channel,
        user_ids=[admin_user_id],
    )

    before = datetime.now(timezone.utc)
    credential = CredentialManager.create(
        source=DocumentSource.SLACK,
        credential_json={
            "slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
        },
        user_performing_action=admin_user,
    )

    # Create connector with sync access and assign it to the user group
    connector = ConnectorManager.create(
        name="Slack",
        input_type=InputType.POLL,
        source=DocumentSource.SLACK,
        connector_specific_config={
            "workspace": "onyx-test-workspace",
            "channels": [private_channel["name"]],
        },
        access_type=AccessType.SYNC,
        groups=[user_group.id],
        user_performing_action=admin_user,
    )

    cc_pair = CCPairManager.create(
        credential_id=credential.id,
        connector_id=connector.id,
        access_type=AccessType.SYNC,
        user_performing_action=admin_user,
        groups=[user_group.id],
    )

    # Add a test message to the private channel
    private_message = "This is a secret message: 987654"
    SlackManager.add_message_to_channel(
        slack_client=slack_client,
        channel=private_channel,
        message=private_message,
    )

    # Run indexing
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
    )

    # Run permission sync
    CCPairManager.sync(
        cc_pair=cc_pair,
        user_performing_action=admin_user,
    )
    CCPairManager.wait_for_sync(
        cc_pair=cc_pair,
        after=before,
        number_of_updated_docs=1,
        user_performing_action=admin_user,
    )

    # Verify admin can see the message
    admin_docs = DocumentSearchManager.search_documents(
        query="secret message",
        user_performing_action=admin_user,
    )
    assert private_message in admin_docs

    # Verify test_user_1 cannot see the message despite being in the group
    # (Slack permissions should take precedence)
    user_1_docs = DocumentSearchManager.search_documents(
        query="secret message",
        user_performing_action=test_user_1,
    )
    assert private_message not in user_1_docs

@@ -74,7 +74,7 @@ def test_slack_prune(
        access_type=AccessType.SYNC,
        user_performing_action=admin_user,
    )
    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_prune(
    # Run indexing
    before = datetime.now(timezone.utc)
    CCPairManager.run_once(cc_pair, admin_user)
    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,

@@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
        user_performing_action=admin_user,
    )

    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair_1, now, timeout=120, user_performing_action=admin_user
    )

@@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
        user_performing_action=admin_user,
    )

    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair_2, now, timeout=120, user_performing_action=admin_user
    )

@@ -82,48 +82,3 @@ def test_overlapping_connector_creation(reset: None) -> None:
    assert info_2

    assert info_1.num_docs_indexed == info_2.num_docs_indexed


def test_connector_pause_while_indexing(reset: None) -> None:
    """Tests that we can pause a connector while indexing is in progress and that
    tasks end early or abort as a result.

    TODO: This does not specifically test for soft or hard termination code paths.
    Design specific tests for those use cases.
    """
    admin_user: DATestUser = UserManager.create(name="admin_user")

    config = {
        "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
        "space": "",
        "is_cloud": True,
        "page_id": "",
    }

    credential = {
        "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
        "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
    }

    # store the time before we create the connector so that we know after
    # when the indexing should have started
    datetime.now(timezone.utc)

    # create connector
    cc_pair_1 = CCPairManager.create_from_scratch(
        source=DocumentSource.CONFLUENCE,
        connector_specific_config=config,
        credential_json=credential,
        user_performing_action=admin_user,
    )

    CCPairManager.wait_for_indexing_in_progress(
        cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user
    )

    CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user)

    CCPairManager.wait_for_indexing_inactive(
        cc_pair_1, timeout=60, user_performing_action=admin_user
    )
    return

@@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
        user_performing_action=admin_user,
    )

    CCPairManager.wait_for_indexing_completion(
    CCPairManager.wait_for_indexing(
        cc_pair_1, now, timeout=60, user_performing_action=admin_user
    )


@@ -385,16 +385,6 @@ def process_text(
            "Here is some text[[1]](https://0.com). Some other text",
            ["doc_0"],
        ),
        # ['To', ' set', ' up', ' D', 'answer', ',', ' if', ' you', ' are', ' running', ' it', ' yourself', ' and',
        #  ' need', ' access', ' to', ' certain', ' features', ' like', ' auto', '-sync', 'ing', ' document',
        #  '-level', ' access', ' permissions', ',', ' you', ' should', ' reach', ' out', ' to', ' the', ' D',
        #  'answer', ' team', ' to', ' receive', ' access', ' [[', '4', ']].', '']
        (
            "Unique tokens with double brackets and a single token that ends the citation and has characters after it.",
            ["... to receive access", " [[", "1", "]].", ""],
            "... to receive access [[1]](https://0.com).",
            ["doc_0"],
        ),
    ],
)
def test_citation_extraction(

@@ -130,7 +130,6 @@ services:
    restart: always
    environment:
      - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
      - JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-}  # used for JWT authentication of users via API
      # Gen AI Settings (Needed by DanswerBot)
      - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
      - QA_TIMEOUT=${QA_TIMEOUT:-}

node_modules/.package-lock.json (generated, vendored)
@@ -1,6 +0,0 @@
{
  "name": "danswer",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {}
}

web/@types/favicon-fetch.d.ts (vendored)
@@ -1,9 +0,0 @@
declare module "favicon-fetch" {
  interface FaviconFetchOptions {
    uri: string;
  }

  function faviconFetch(options: FaviconFetchOptions): string | null;

  export default faviconFetch;
}

web/package-lock.json (generated): diff suppressed because it is too large

@@ -17,13 +17,11 @@
    "@headlessui/react": "^2.2.0",
    "@headlessui/tailwindcss": "^0.2.1",
    "@phosphor-icons/react": "^2.0.8",
    "@radix-ui/react-checkbox": "^1.1.2",
    "@radix-ui/react-dialog": "^1.1.2",
    "@radix-ui/react-dialog": "^1.0.5",
    "@radix-ui/react-popover": "^1.1.2",
    "@radix-ui/react-select": "^2.1.2",
    "@radix-ui/react-separator": "^1.1.0",
    "@radix-ui/react-slot": "^1.1.0",
    "@radix-ui/react-switch": "^1.1.1",
    "@radix-ui/react-tabs": "^1.1.1",
    "@radix-ui/react-tooltip": "^1.1.3",
    "@sentry/nextjs": "^8.34.0",
@@ -39,7 +37,6 @@
    "class-variance-authority": "^0.7.0",
    "clsx": "^2.1.1",
    "date-fns": "^3.6.0",
    "favicon-fetch": "^1.0.0",
    "formik": "^2.2.9",
    "js-cookie": "^3.0.5",
    "lodash": "^4.17.21",
@@ -70,7 +67,6 @@
    "tailwindcss-animate": "^1.0.7",
    "typescript": "5.0.3",
    "uuid": "^9.0.1",
    "vaul": "^1.1.1",
    "yup": "^1.4.0"
  },
  "devDependencies": {

@@ -24,6 +24,13 @@ import {
  TextFormField,
} from "@/components/admin/connectors/Field";

import {
  Card,
  CardHeader,
  CardTitle,
  CardContent,
  CardFooter,
} from "@/components/ui/card";
import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
@@ -398,7 +405,7 @@ export function AssistantEditor({
          message: `"${assistant.name}" has been added to your list.`,
          type: "success",
        });
        await refreshAssistants();
        router.refresh();
      } else {
        setPopup({
          message: `"${assistant.name}" could not be added to your list.`,

@@ -90,7 +90,7 @@ export function PersonasTable() {
        message: `Failed to update persona order - ${await response.text()}`,
      });
      setFinalPersonas(assistants);
      await refreshAssistants();
      router.refresh();
      return;
    }

@@ -151,7 +151,7 @@ export function PersonasTable() {
                      persona.is_visible
                    );
                    if (response.ok) {
                      await refreshAssistants();
                      router.refresh();
                    } else {
                      setPopup({
                        type: "error",
@@ -183,7 +183,7 @@ export function PersonasTable() {
                    onClick={async () => {
                      const response = await deletePersona(persona.id);
                      if (response.ok) {
                        await refreshAssistants();
                        router.refresh();
                      } else {
                        alert(
                          `Failed to delete persona - ${await response.text()}`

@@ -259,8 +259,29 @@ export async function updatePersona(
): Promise<[Response, Response | null]> {
  const { id, existingPromptId } = personaUpdateRequest;

  let fileId = null;
  if (personaUpdateRequest.uploaded_image) {
    fileId = await uploadFile(personaUpdateRequest.uploaded_image);
    if (!fileId) {
      return [new Response(null, { status: 400 }), null];
    }
  }

  const updatePersonaResponse = await fetch(`/api/persona/${id}`, {
    method: "PATCH",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify(
      buildPersonaAPIBody(personaUpdateRequest, existingPromptId ?? 0, fileId)
    ),
  });

  if (!updatePersonaResponse.ok) {
    return [updatePersonaResponse, null];
  }

  let promptResponse;
  let promptId: number | null = null;
  if (existingPromptId !== undefined) {
    promptResponse = await updatePrompt({
      promptId: existingPromptId,
@@ -269,7 +290,6 @@ export async function updatePersona(
      taskPrompt: personaUpdateRequest.task_prompt,
      includeCitations: personaUpdateRequest.include_citations,
    });
    promptId = existingPromptId;
  } else {
    promptResponse = await createPrompt({
      personaName: personaUpdateRequest.name,
@@ -277,30 +297,7 @@ export async function updatePersona(
      taskPrompt: personaUpdateRequest.task_prompt,
      includeCitations: personaUpdateRequest.include_citations,
    });
    promptId = promptResponse.ok
      ? ((await promptResponse.json()).id as number)
      : null;
  }
  let fileId = null;
  if (personaUpdateRequest.uploaded_image) {
    fileId = await uploadFile(personaUpdateRequest.uploaded_image);
    if (!fileId) {
      return [promptResponse, null];
    }
  }

  const updatePersonaResponse =
    promptResponse.ok && promptId !== null
      ? await fetch(`/api/persona/${id}`, {
          method: "PATCH",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify(
            buildPersonaAPIBody(personaUpdateRequest, promptId, fileId)
          ),
        })
      : null;

  return [promptResponse, updatePersonaResponse];
}
@@ -60,24 +60,21 @@ export function SlackChannelConfigsTable({
|
||||
.slice(numToDisplay * (page - 1), numToDisplay * page)
|
||||
.map((slackChannelConfig) => {
|
||||
return (
|
||||
<TableRow
|
||||
key={slackChannelConfig.id}
|
||||
className="cursor-pointer hover:bg-gray-100 transition-colors"
|
||||
onClick={() => {
|
||||
window.location.href = `/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`;
|
||||
}}
|
||||
>
|
||||
<TableRow key={slackChannelConfig.id}>
|
||||
<TableCell>
|
||||
<div className="flex gap-x-2">
|
||||
<div className="my-auto">
|
||||
<Link
|
||||
className="cursor-pointer my-auto"
|
||||
href={`/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`}
|
||||
>
|
||||
<EditIcon />
|
||||
</div>
|
||||
</Link>
|
||||
<div className="my-auto">
|
||||
{"#" + slackChannelConfig.channel_config.channel_name}
|
||||
</div>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell onClick={(e) => e.stopPropagation()}>
|
||||
<TableCell>
|
||||
{slackChannelConfig.persona &&
|
||||
!isPersonaASlackBotPersona(slackChannelConfig.persona) ? (
|
||||
<Link
|
||||
@@ -101,11 +98,10 @@ export function SlackChannelConfigsTable({
|
||||
: "-"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell onClick={(e) => e.stopPropagation()}>
|
||||
<TableCell>
|
||||
<div
|
||||
className="cursor-pointer hover:text-destructive"
|
||||
onClick={async (e) => {
|
||||
e.stopPropagation();
|
||||
onClick={async () => {
|
||||
const response = await deleteSlackChannelConfig(
|
||||
slackChannelConfig.id
|
||||
);
|
||||
|
||||
@@ -81,11 +81,6 @@ export const SlackChannelConfigCreationForm = ({
|
||||
respond_to_bots:
|
||||
existingSlackChannelConfig?.channel_config?.respond_to_bots ||
|
||||
false,
|
||||
show_continue_in_web_ui:
|
||||
// If we're updating, we want to keep the existing value
|
||||
// Otherwise, we want to default to true
|
||||
existingSlackChannelConfig?.channel_config
|
||||
?.show_continue_in_web_ui ?? !isUpdate,
|
||||
enable_auto_filters:
|
||||
existingSlackChannelConfig?.enable_auto_filters || false,
|
||||
respond_member_group_list:
|
||||
@@ -124,7 +119,6 @@ export const SlackChannelConfigCreationForm = ({
|
||||
questionmark_prefilter_enabled: Yup.boolean().required(),
|
||||
respond_tag_only: Yup.boolean().required(),
|
||||
respond_to_bots: Yup.boolean().required(),
|
||||
show_continue_in_web_ui: Yup.boolean().required(),
|
||||
enable_auto_filters: Yup.boolean().required(),
|
||||
respond_member_group_list: Yup.array().of(Yup.string()).required(),
|
||||
still_need_help_enabled: Yup.boolean().required(),
|
||||
@@ -276,13 +270,7 @@ export const SlackChannelConfigCreationForm = ({
|
||||
|
||||
{showAdvancedOptions && (
|
||||
<div className="mt-4">
|
||||
<BooleanFormField
|
||||
name="show_continue_in_web_ui"
|
||||
removeIndent
|
||||
label="Show Continue in Web UI button"
|
||||
tooltip="If set, will show a button at the bottom of the response that allows the user to continue the conversation in the Danswer Web UI"
|
||||
/>
|
||||
<div className="w-64 mb-4 mt-4">
|
||||
<div className="w-64 mb-4">
|
||||
<SelectorFormField
|
||||
name="response_type"
|
||||
label="Answer Type"
|
||||
|
||||
@@ -15,7 +15,6 @@ interface SlackChannelConfigCreationRequest {
|
||||
questionmark_prefilter_enabled: boolean;
|
||||
respond_tag_only: boolean;
|
||||
respond_to_bots: boolean;
|
||||
show_continue_in_web_ui: boolean;
|
||||
respond_member_group_list: string[];
|
||||
follow_up_tags?: string[];
|
||||
usePersona: boolean;
|
||||
@@ -44,7 +43,6 @@ const buildRequestBodyFromCreationRequest = (
|
||||
channel_name: creationRequest.channel_name,
|
||||
respond_tag_only: creationRequest.respond_tag_only,
|
||||
respond_to_bots: creationRequest.respond_to_bots,
|
||||
show_continue_in_web_ui: creationRequest.show_continue_in_web_ui,
|
||||
enable_auto_filters: creationRequest.enable_auto_filters,
|
||||
respond_member_group_list: creationRequest.respond_member_group_list,
|
||||
answer_filters: buildFiltersFromCreationRequest(creationRequest),
|
||||
|
||||
@@ -22,6 +22,7 @@ function SlackBotEditPage({
|
||||
const unwrappedParams = use(params);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
console.log("unwrappedParams", unwrappedParams);
|
||||
const {
|
||||
data: slackBot,
|
||||
isLoading: isSlackBotLoading,
|
||||
|
||||
@@ -161,7 +161,7 @@ export default function UpgradingPage({
|
||||
reindexingProgress={sortedReindexingProgress}
|
||||
/>
|
||||
) : (
|
||||
<ErrorCallout errorTitle="Failed to fetch reindexing progress" />
|
||||
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
@@ -171,7 +171,7 @@ export default function UpgradingPage({
|
||||
</h3>
|
||||
<p className="mb-4 text-text-800">
|
||||
You're currently switching embedding models, but there
|
||||
are no connectors to reindex. This means the transition will
|
||||
are no connectors to re-index. This means the transition will
|
||||
be quick and seamless!
|
||||
</p>
|
||||
<p className="text-text-600">
|
||||
|
||||
@@ -6,8 +6,6 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { setCCPairStatus } from "@/lib/ccPair";
import { useState } from "react";
import { LoadingAnimation } from "@/components/Loading";

export function ModifyStatusButtonCluster({
ccPair,
@@ -15,72 +13,44 @@ export function ModifyStatusButtonCluster({
ccPair: CCPairFullInfo;
}) {
const { popup, setPopup } = usePopup();
const [isUpdating, setIsUpdating] = useState(false);

const handleStatusChange = async (
newStatus: ConnectorCredentialPairStatus
) => {
if (isUpdating) return; // Prevent double-clicks or multiple requests
setIsUpdating(true);

try {
// Call the backend to update the status
await setCCPairStatus(ccPair.id, newStatus, setPopup);

// Use mutate to revalidate the status on the backend
await mutate(buildCCPairInfoUrl(ccPair.id));
} catch (error) {
console.error("Failed to update status", error);
} finally {
// Reset local updating state and button text after mutation
setIsUpdating(false);
}
};

// Compute the button text based on current state and backend status
const buttonText =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Re-Enable"
: "Pause";

const tooltip =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Click to start indexing again!"
: "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";

return (
<>
{popup}
<Button
className="flex items-center justify-center w-auto min-w-[100px] px-4 py-2"
variant={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "success-reverse"
: "default"
}
disabled={isUpdating}
onClick={() =>
handleStatusChange(
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? ConnectorCredentialPairStatus.ACTIVE
: ConnectorCredentialPairStatus.PAUSED
)
}
tooltip={tooltip}
>
{isUpdating ? (
<LoadingAnimation
text={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Resuming"
: "Pausing"
}
size="text-md"
/>
) : (
buttonText
)}
</Button>
{ccPair.status === ConnectorCredentialPairStatus.PAUSED ? (
<Button
variant="success-reverse"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.ACTIVE,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip="Click to start indexing again!"
>
Re-Enable
</Button>
) : (
<Button
variant="default"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.PAUSED,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip={
"When paused, the connectors documents will still" +
" be visible. However, no new documents will be indexed."
}
>
Pause
</Button>
)}
</>
);
}

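As an illustrative aside (not part of the diff): the first ModifyStatusButtonCluster version above guards its async status toggle with an isUpdating flag, so a second click cannot issue a duplicate request while one is still in flight, and it revalidates the cc-pair info after the write so the UI reflects the backend state. A minimal sketch of that pattern in plain TypeScript follows; setStatusOnBackend and revalidate are hypothetical stand-ins for setCCPairStatus and mutate(buildCCPairInfoUrl(...)), and the enum is reduced to the two states used above.

enum ConnectorCredentialPairStatus {
  ACTIVE = "ACTIVE",
  PAUSED = "PAUSED",
}

// Hypothetical stand-ins for the real helpers used in the component above.
type SetStatusOnBackend = (newStatus: ConnectorCredentialPairStatus) => Promise<void>;
type Revalidate = () => Promise<void>;

let isUpdating = false; // in the component this is React state (useState)

async function toggleStatus(
  currentStatus: ConnectorCredentialPairStatus,
  setStatusOnBackend: SetStatusOnBackend,
  revalidate: Revalidate
): Promise<void> {
  if (isUpdating) return; // ignore double-clicks while a request is in flight
  isUpdating = true;

  // Pausing and resuming are the same action with the opposite target status.
  const newStatus =
    currentStatus === ConnectorCredentialPairStatus.PAUSED
      ? ConnectorCredentialPairStatus.ACTIVE
      : ConnectorCredentialPairStatus.PAUSED;

  try {
    await setStatusOnBackend(newStatus); // persist the change on the backend
    await revalidate(); // refetch so the UI shows the new status
  } catch (error) {
    console.error("Failed to update status", error);
  } finally {
    isUpdating = false; // always clear the in-flight flag
  }
}

Deriving the button label ("Re-Enable" vs. "Pause") and the LoadingAnimation text from the same status check keeps the rendered state consistent with this single toggle.
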
@@ -121,7 +121,7 @@ export function ReIndexButton({
{popup}
<Button
variant="success-reverse"
className="ml-auto min-w-[100px]"
className="ml-auto"
onClick={() => {
setReIndexPopupVisible(true);
}}

@@ -25,7 +25,6 @@ import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";

// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not

@@ -83,7 +83,7 @@ const EditRow = ({
</div>
</TooltipTrigger>
{!documentSet.is_up_to_date && (
<TooltipContent width="max-w-sm">
<TooltipContent maxWidth="max-w-sm">
<div className="flex break-words break-keep whitespace-pre-wrap items-start">
<InfoIcon className="mr-2 mt-0.5" />
Cannot update while syncing! Wait for the sync to finish, then