Compare commits

..

2 Commits

Author SHA1 Message Date
pablodanswer
ca418fdcf2 order seeding 2024-11-27 12:16:42 -08:00
pablodanswer
4a1230f028 proper no assistant typing + no assistant modal 2024-11-27 09:55:17 -08:00
175 changed files with 4015 additions and 5989 deletions

View File

@@ -24,8 +24,6 @@ env:
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

View File

@@ -73,7 +73,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

View File

@@ -1,35 +0,0 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -1,27 +0,0 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -1,30 +0,0 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -5,6 +5,7 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -36,7 +37,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -59,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, lock_beat, tenant_id
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -85,6 +86,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:

View File

@@ -8,7 +8,6 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -28,7 +27,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -139,7 +138,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -163,7 +162,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
result = app.send_task(
app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -175,8 +174,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
payload = RedisConnectorPermissionSyncData(
started=None,
)
redis_connector.permissions.set_fence(payload)
@@ -242,17 +241,13 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
raise ValueError(f"No doc sync func found for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
logger.info(f"Syncing docs for {source_type}")
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

View File

@@ -8,7 +8,6 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -25,9 +24,6 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -53,7 +49,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip external group sync if not active
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
@@ -111,7 +107,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -129,7 +125,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_external_group_sync_task(
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -160,7 +156,7 @@ def try_creating_external_group_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
_ = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -170,13 +166,8 @@ def try_creating_external_group_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
except Exception:
task_logger.exception(
@@ -204,7 +195,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles external group syncing for a given connector credential pair
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -212,7 +203,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = r.lock(
lock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -237,13 +228,9 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
raise ValueError(f"No external group sync func found for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
@@ -262,6 +249,7 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -272,6 +260,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(None)
redis_connector.external_group_sync.set_fence(False)
if lock.owned():
lock.release()

View File

@@ -25,13 +25,11 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
@@ -39,13 +37,12 @@ from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -162,7 +159,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -175,8 +172,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
@@ -210,10 +205,17 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -229,46 +231,22 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
reindex,
False,
db_session,
r,
tenant_id,
@@ -278,7 +256,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id}"
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
@@ -303,6 +281,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -310,14 +289,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -326,7 +304,6 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -391,11 +368,6 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -523,11 +495,8 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -540,10 +509,6 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
@@ -572,30 +537,8 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
job.cancel()
break
sleep(10)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():

View File

@@ -46,7 +46,6 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -59,7 +58,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
RedisConnectorPermissionSyncData,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -589,7 +588,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
payload: RedisConnectorPermissionSyncData | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -597,7 +596,9 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
def monitor_ccpair_indexing_taskset(
@@ -677,15 +678,11 @@ def monitor_ccpair_indexing_taskset(
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
@@ -695,7 +692,6 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -728,7 +724,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)

View File

@@ -1,8 +1,6 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = celery_app
app = celery_app

View File

@@ -1,10 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
app = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -88,10 +87,6 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
@@ -213,7 +208,9 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -307,16 +304,26 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
reason=str(e),
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -328,37 +335,6 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure

View File

@@ -605,7 +605,6 @@ def stream_chat_message_objects(
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)

View File

@@ -493,6 +493,10 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs

View File

@@ -70,9 +70,7 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible

View File

@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
- Load Connector:
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
documents via a connector's API or loads the documents from some sort of a dump file.
- Poll Connector:
- Poll connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Slim Connector:
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
- This is used by our pruning job which removes old documents from the index.
- The optional start and end datetimes can be ignored.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.
- Currently not used by the background job, this exists for future design purposes.
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
_SLIM_DOC_BATCH_SIZE = 1000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -301,8 +301,5 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=perm_sync_data,
)
)
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list
yield doc_metadata_list
doc_metadata_list = []

View File

@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 1000
_DEFAULT_PAGINATION_LIMIT = 100
class OnyxConfluence(Confluence):

View File

@@ -12,15 +12,12 @@ from dateutil import parser
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
@@ -31,8 +28,6 @@ logger = setup_logger()
SLAB_GRAPHQL_MAX_TRIES = 10
SLAB_API_URL = "https://api.slab.com/v1/graphql"
_SLIM_BATCH_SIZE = 1000
def run_graphql_request(
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
class SlabConnector(LoadConnector, PollConnector):
def __init__(
self,
base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
slab_bot_token: str | None = None,
) -> None:
self.base_url = base_url
self.batch_size = batch_size
self._slab_bot_token: str | None = None
self.slab_bot_token = slab_bot_token
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._slab_bot_token = credentials["slab_bot_token"]
self.slab_bot_token = credentials["slab_bot_token"]
return None
@property
def slab_bot_token(self) -> str:
if self._slab_bot_token is None:
raise ConnectorMissingCredentialError("Slab")
return self._slab_bot_token
def _iterate_posts(
self, time_filter: Callable[[datetime], bool] | None = None
) -> GenerateDocumentsOutput:
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
yield from self._iterate_posts(
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):
slim_doc_batch.append(
SlimDocument(
id=post_id,
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch

View File

@@ -18,30 +18,20 @@ from slack_sdk.models.blocks.block_elements import ImageElement
from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.context.search.models import SavedSearchDoc
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.icons import source_to_github_img_link
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.db.chat import get_chat_session_by_message_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
@@ -111,12 +101,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
return chunks
def _clean_markdown_link_text(text: str) -> str:
def clean_markdown_link_text(text: str) -> str:
# Remove any newlines within the text
return text.replace("\n", " ").strip()
def _build_qa_feedback_block(
def build_qa_feedback_block(
message_id: int, feedback_reminder_id: str | None = None
) -> Block:
return ActionsBlock(
@@ -125,6 +115,7 @@ def _build_qa_feedback_block(
ButtonElement(
action_id=LIKE_BLOCK_ACTION_ID,
text="👍 Helpful",
style="primary",
value=feedback_reminder_id,
),
ButtonElement(
@@ -164,7 +155,7 @@ def get_document_feedback_blocks() -> Block:
)
def _build_doc_feedback_block(
def build_doc_feedback_block(
message_id: int,
document_id: str,
document_rank: int,
@@ -191,7 +182,7 @@ def get_restate_blocks(
]
def _build_documents_blocks(
def build_documents_blocks(
documents: list[SavedSearchDoc],
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
@@ -232,7 +223,7 @@ def _build_documents_blocks(
feedback: ButtonElement | dict = {}
if message_id is not None:
feedback = _build_doc_feedback_block(
feedback = build_doc_feedback_block(
message_id=message_id,
document_id=d.document_id,
document_rank=rank,
@@ -250,7 +241,7 @@ def _build_documents_blocks(
return section_blocks
def _build_sources_blocks(
def build_sources_blocks(
cited_documents: list[tuple[int, SavedSearchDoc]],
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
@@ -295,7 +286,7 @@ def _build_sources_blocks(
+ ([days_ago_str] if days_ago_str else [])
)
document_title = _clean_markdown_link_text(doc_sem_id)
document_title = clean_markdown_link_text(doc_sem_id)
img_link = source_to_github_img_link(d.source_type)
section_blocks.append(
@@ -326,50 +317,7 @@ def _build_sources_blocks(
return section_blocks
def _priority_ordered_documents_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
if not priority_ordered_docs:
return []
document_blocks = _build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
if document_blocks:
document_blocks = [DividerBlock()] + document_blocks
return document_blocks
def _build_citations_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = _build_sources_blocks(cited_documents=cited_docs)
return citations_block
def _build_quotes_block(
def build_quotes_block(
quotes: list[DanswerQuote],
) -> list[Block]:
quote_lines: list[str] = []
@@ -411,70 +359,58 @@ def _build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def _build_qa_response_blocks(
answer: OneShotQAResponse,
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
quotes: list[DanswerQuote] | None,
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
skip_quotes: bool = False,
process_message_for_citations: bool = False,
skip_ai_feedback: bool = False,
feedback_reminder_id: str | None = None,
) -> list[Block]:
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
quotes = answer.quotes.quotes if answer.quotes else None
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
filter_block: Block | None = None
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
or retrieval_info.applied_source_filters
):
if time_cutoff or favor_recent or source_filters:
filter_text = "Filters: "
if retrieval_info.applied_source_filters:
sources_str = ", ".join(
[s.value for s in retrieval_info.applied_source_filters]
)
if source_filters:
sources_str = ", ".join([s.value for s in source_filters])
filter_text += f"`Sources in [{sources_str}]`"
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
):
if time_cutoff or favor_recent:
filter_text += " and "
if retrieval_info.applied_time_cutoff is not None:
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
if time_cutoff is not None:
time_str = time_cutoff.strftime("%b %d, %Y")
filter_text += f"`Docs Updated >= {time_str}` "
if retrieval_info.recency_bias_multiplier > 1:
if retrieval_info.applied_time_cutoff is not None:
if favor_recent:
if time_cutoff is not None:
filter_text += "+ "
filter_text += "`Prioritize Recently Updated Docs`"
filter_block = SectionBlock(text=f"_{filter_text}_")
if not formatted_answer:
if not answer:
answer_blocks = [
SectionBlock(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
]
else:
answer_processed = decode_escapes(
remove_slack_text_interactions(formatted_answer)
)
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
if process_message_for_citations:
answer_processed = _process_citations_for_slack(answer_processed)
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
if quotes:
quotes_blocks = _build_quotes_block(quotes)
quotes_blocks = build_quotes_block(quotes)
# if no quotes OR `_build_quotes_block()` did not give back any blocks
# if no quotes OR `build_quotes_block()` did not give back any blocks
if not quotes_blocks:
quotes_blocks = [
SectionBlock(
@@ -489,37 +425,20 @@ def _build_qa_response_blocks(
response_blocks.extend(answer_blocks)
if message_id is not None and not skip_ai_feedback:
response_blocks.append(
build_qa_feedback_block(
message_id=message_id, feedback_reminder_id=feedback_reminder_id
)
)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
return response_blocks
def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
)
return ActionsBlock(
block_id=build_continue_in_web_ui_id(message_id),
elements=[
ButtonElement(
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
text="Continue Chat in Danswer!",
style="primary",
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
),
],
)
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
return ActionsBlock(
block_id=build_feedback_id(message_id) if message_id is not None else None,
elements=[
@@ -564,77 +483,3 @@ def build_follow_up_resolved_blocks(
]
)
return [text_block, button_block]
def build_slack_response_blocks(
tenant_id: str | None,
message_info: SlackMessageInfo,
answer: OneShotQAResponse,
persona: Persona | None,
channel_conf: ChannelConfig | None,
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
) -> list[Block]:
"""
This function is a top level function that builds all the blocks for the Slack response.
It also handles combining all the blocks together.
"""
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
answer_blocks = _build_qa_response_blocks(
answer=answer,
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
)
web_follow_up_block = []
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
message_id=answer.chat_message_id,
feedback_reminder_id=feedback_reminder_id,
)
)
citations_blocks = []
document_blocks = []
if use_citations:
# if citations are enabled, only show cited documents
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
all_blocks = (
restate_question_block
+ answer_blocks
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block
)
return all_blocks

View File

@@ -2,7 +2,6 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"

View File

@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
@@ -267,7 +267,7 @@ def handle_followup_button(
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_slack_user_ids_from_emails(
tag_ids, remaining = fetch_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:

View File

@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
@@ -184,7 +184,7 @@ def handle_message(
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_slack_user_ids_from_emails(
send_to, missing_ids = fetch_user_ids_from_emails(
respond_member_group_list, client
)

View File

@@ -7,6 +7,7 @@ from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
@@ -24,7 +25,12 @@ from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
@@ -405,16 +411,62 @@ def handle_regular_answer(
)
return True
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
persona=persona,
channel_conf=channel_conf,
use_citations=use_citations,
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=formatted_answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,

View File

@@ -3,9 +3,9 @@ import random
import re
import string
import time
import uuid
from typing import Any
from typing import cast
from typing import Optional
from retry import retry
from slack_sdk import WebClient
@@ -216,13 +216,6 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
unique_prefix = str(uuid.uuid4())[:10]
return unique_prefix + ID_SEPARATOR + str(message_id)
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
"""Decompose into query_id, document_id, document_rank, see above function"""
try:
@@ -320,7 +313,7 @@ def get_channel_name_from_id(
raise e
def fetch_slack_user_ids_from_emails(
def fetch_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -529,7 +522,7 @@ class SlackRateLimiter:
self.last_reset_time = time.time()
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: str | None
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
) -> None:
respond_in_thread(
client=client,

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -31,7 +30,6 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
@@ -252,50 +250,6 @@ def create_chat_session(
return chat_session
def duplicate_chat_session_for_user_from_slack(
db_session: Session,
user: User | None,
chat_session_id: UUID,
) -> ChatSession:
"""
This takes a chat session id for a session in Slack and:
- Creates a new chat session in the DB
- Tries to copy the persona from the original chat session
(if it is available to the user clicking the button)
- Sets the user to the given user (if provided)
"""
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=None, # Ignore user permissions for this
db_session=db_session,
)
if not chat_session:
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
# This enforces permissions and sets a default
new_persona_id = get_best_persona_id_for_user(
db_session=db_session,
user=user,
persona_id=chat_session.persona_id,
)
return create_chat_session(
db_session=db_session,
user_id=user.id if user else None,
persona_id=new_persona_id,
# Set this to empty string so the frontend will force a rename
description="",
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat sessions from Slack should put people in the chat UI, not the search
one_shot=False,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack
slack_thread_id=None,
)
def update_chat_session(
db_session: Session,
user_id: UUID | None,
@@ -382,28 +336,6 @@ def get_chat_message(
return chat_message
def get_chat_session_by_message_id(
db_session: Session,
message_id: int,
) -> ChatSession:
"""
Should only be used for Slack
Get the chat session associated with a specific message ID
Note: this ignores permission checks.
"""
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
result = db_session.execute(stmt)
chat_message = result.scalar_one_or_none()
if chat_message is None:
raise ValueError(
f"Unable to find chat session associated with message ID: {message_id}"
)
return chat_message.chat_session
def get_chat_messages_by_sessions(
chat_session_ids: list[UUID],
user_id: UUID | None,
@@ -423,44 +355,6 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()
def add_chats_to_session_from_slack_thread(
db_session: Session,
slack_chat_session_id: UUID,
new_chat_session_id: UUID,
) -> None:
new_root_message = get_or_create_root_message(
chat_session_id=new_chat_session_id,
db_session=db_session,
)
for chat_message in get_chat_messages_by_sessions(
chat_session_ids=[slack_chat_session_id],
user_id=None, # Ignore user permissions for this
db_session=db_session,
skip_permission_check=True,
):
if chat_message.message_type == MessageType.SYSTEM:
continue
# Duplicate the message
new_root_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_root_message,
message=chat_message.message,
files=chat_message.files,
rephrased_query=chat_message.rephrased_query,
error=chat_message.error,
citations=chat_message.citations,
reference_docs=chat_message.search_docs,
tool_call=chat_message.tool_call,
prompt_id=chat_message.prompt_id,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:

View File

@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingMode
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -312,25 +311,3 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()
def mark_ccpair_with_indexing_trigger(
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
"""indexing_mode sets a field which will be picked up by a background task
to trigger indexing. Set to None to disable the trigger."""
try:
cc_pair = db_session.execute(
select(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.with_for_update()
).scalar_one()
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.indexing_trigger = indexing_mode
db_session.commit()
except Exception:
db_session.rollback()
raise

View File

@@ -324,11 +324,8 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
user_group_ids: list[int] | None = None,
user_group_ids: list[int],
) -> None:
if not user_group_ids:
return
for group_id in user_group_ids:
db_session.add(
UserGroup__ConnectorCredentialPair(
@@ -405,11 +402,12 @@ def add_credential_to_connector(
db_session.flush() # make sure the association has an id
db_session.refresh(association)
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
db_session.commit()

View File

@@ -19,11 +19,6 @@ class IndexingStatus(str, PyEnum):
return self in terminal_states
class IndexingMode(str, PyEnum):
UPDATE = "update"
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"

View File

@@ -42,7 +42,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType, IndexingMode
from danswer.db.enums import AccessType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -126,7 +126,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -439,10 +438,6 @@ class ConnectorCredentialPair(Base):
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
Enum(IndexingMode, native_enum=False), nullable=True
)
connector: Mapped["Connector"] = relationship(
"Connector", back_populates="credentials"
)
@@ -1485,7 +1480,6 @@ class ChannelConfig(TypedDict):
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
show_continue_in_web_ui: NotRequired[bool] # defaults to False
class SlackBotResponseType(str, PyEnum):

View File

@@ -113,31 +113,6 @@ def fetch_persona_by_id(
return persona
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
if persona_id is not None:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(
stmt=stmt,
user=user,
# We don't want to filter by editable here, we just want to see if the
# persona is usable by the user
get_editable=False,
)
persona = db_session.scalars(stmt).one_or_none()
if persona:
return persona.id
# If the persona is not found, or the slack bot is using doc sets instead of personas,
# we need to find the best persona for the user
# This is the persona with the highest display priority that the user has access to
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
persona = db_session.scalars(stmt).one_or_none()
return persona.id if persona else None
def _get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
@@ -185,7 +160,7 @@ def create_update_persona(
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.model_dump(exclude={"users", "groups"}),
**create_persona_request.dict(exclude={"users", "groups"}),
}
persona = upsert_persona(**persona_data)
@@ -758,8 +733,6 @@ def get_prompt_by_name(
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result

View File

@@ -143,25 +143,6 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)

View File

@@ -295,7 +295,7 @@ def pptx_to_text(file: IO[Any]) -> str:
def xlsx_to_text(file: IO[Any]) -> str:
workbook = openpyxl.load_workbook(file, read_only=True)
workbook = openpyxl.load_workbook(file)
text_content = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(

View File

@@ -59,12 +59,6 @@ class FileStore(ABC):
Contents of the file and metadata dict
"""
@abstractmethod
def read_file_record(self, file_name: str) -> PGFileStore:
"""
Read the file record by the name
"""
@abstractmethod
def delete_file(self, file_name: str) -> None:
"""

View File

@@ -67,9 +67,9 @@ class CitationProcessor:
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
@@ -77,15 +77,13 @@ class CitationProcessor:
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
result = ""
result = "" # Initialize result here
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)
numerical_value = int(citation.group(1))
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
@@ -133,6 +131,14 @@ class CitationProcessor:
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
@@ -143,7 +149,6 @@ class CitationProcessor:
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (

View File

@@ -26,9 +26,7 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
@@ -163,9 +161,7 @@ def _convert_delta_to_message_chunk(
if role == "user":
return HumanMessageChunk(content=content)
# NOTE: if tool calls are present, then it's an assistant.
# In Ollama, the role will be None for tool-calls
elif role == "assistant" or tool_calls:
elif role == "assistant":
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -240,7 +236,6 @@ class DefaultMultiLLM(LLM):
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
model_kwargs: dict[str, Any] | None = None,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
@@ -273,7 +268,7 @@ class DefaultMultiLLM(LLM):
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs = model_kwargs or {}
model_kwargs: dict[str, Any] = {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:

View File

@@ -1,8 +1,5 @@
from typing import Any
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
@@ -16,15 +13,6 @@ from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.long_term_log import LongTermLogger
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
TODO: allow model-specific values to be configured via the UI.
"""
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
@@ -144,6 +132,5 @@ def get_llm(
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
model_kwargs=_build_extra_model_kwargs(provider),
long_term_logger=long_term_logger,
)

View File

@@ -1,4 +1,3 @@
import copy
import io
import json
from collections.abc import Callable
@@ -386,62 +385,6 @@ def test_llm(llm: LLM) -> str | None:
return error_msg
def get_model_map() -> dict:
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
# NOTE: we could add additional models here in the future,
# but for now there is no point. Ollama allows the user to
# to specify their desired max context window, and it's
# unlikely to be standard across users even for the same model
# (it heavily depends on their hardware). For now, we'll just
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
# for model_name in [
# "llama3.2",
# "llama3.2:1b",
# "llama3.2:3b",
# "llama3.2:11b",
# "llama3.2:90b",
# ]:
# starting_map[f"ollama/{model_name}"] = {
# "max_tokens": 128000,
# "max_input_tokens": 128000,
# "max_output_tokens": 128000,
# }
return starting_map
def _strip_extra_provider_from_model_name(model_name: str) -> str:
return model_name.split("/")[1] if "/" in model_name else model_name
def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]
# First try all model names with provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj
# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj
return None
def get_llm_max_tokens(
model_map: dict,
model_name: str,
@@ -454,22 +397,22 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS
try:
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where user has a
# customer model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
model_obj = model_map.get(f"{model_provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if not model_obj:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
if not model_obj:
model_name_split = model_name.split("/")
if len(model_name_split) > 1:
model_obj = model_map.get(model_name_split[1])
if model_obj:
logger.debug(f"Using model object for {model_name_split[1]}")
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
@@ -545,7 +488,7 @@ def get_max_input_tokens(
# `model_cost` dict is a named public interface:
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
# model_map is litellm.model_cost
litellm_model_map = get_model_map()
litellm_model_map = litellm.model_cost
input_toks = (
get_llm_max_tokens(

View File

@@ -26,7 +26,6 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import create_danswer_oauth_router
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -45,7 +44,6 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.db.engine import warm_up_connections
from danswer.server.api_key.api import router as api_key_router
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -282,7 +280,6 @@ def get_application() -> FastAPI:
application, get_full_openai_assistants_api_router()
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -326,7 +323,7 @@ def get_application() -> FastAPI:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
create_danswer_oauth_router(
fastapi_users.get_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,

View File

@@ -1,4 +0,0 @@
class ModelServerRateLimitError(Exception):
"""
Exception raised for rate limiting errors from the model server.
"""

View File

@@ -6,9 +6,6 @@ from typing import Any
import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry
from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -19,9 +16,6 @@ from danswer.configs.model_configs import (
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.exceptions import (
ModelServerRateLimitError,
)
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -105,43 +99,28 @@ class EmbeddingModel:
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> Response:
def _make_request() -> EmbedResponse:
response = requests.post(
self.embed_server_endpoint, json=embed_request.model_dump()
)
# signify that this is a rate limit error
if response.status_code == 429:
raise ModelServerRateLimitError(response.text)
response.raise_for_status()
return response
final_make_request_func = _make_request
# if the text type is a passage, add some default
# retries + handling for rate limiting
if embed_request.text_type == EmbedTextType.PASSAGE:
final_make_request_func = retry(
tries=3,
delay=5,
exceptions=(RequestException, ValueError, JSONDecodeError),
)(final_make_request_func)
# use 10 second delay as per Azure suggestion
final_make_request_func = retry(
tries=10, delay=10, exceptions=ModelServerRateLimitError
)(final_make_request_func)
try:
response = final_make_request_func()
return EmbedResponse(**response.json())
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
response.raise_for_status()
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
return EmbedResponse(**response.json())
# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
if embed_request.text_type == EmbedTextType.PASSAGE:
return retry(tries=3, delay=5)(_make_request)()
else:
return _make_request()
def _batch_encode_texts(
self,

View File

@@ -131,7 +131,7 @@ def _try_initialize_tokenizer(
return tokenizer
except Exception as hf_error:
logger.warning(
f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
)
# If both initializations fail, return None

View File

@@ -47,7 +47,6 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.one_shot_answer.qa_utils import slackify_message_thread
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -195,22 +194,13 @@ def stream_answer_objects(
)
prompt = persona.prompts[0]
user_message_str = query_msg.message
# For this endpoint, we only save one user message to the chat session
# However, for slackbot, we want to include the history of the entire thread
if danswerbot_flow:
# Right now, we only support bringing over citations and search docs
# from the last message in the thread, not the entire thread
# in the future, we may want to retrieve the entire thread
user_message_str = slackify_message_thread(query_req.messages)
# Create the first User query message
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=user_message_str,
token_count=len(llm_tokenizer.encode(user_message_str)),
message=query_msg.message,
token_count=len(llm_tokenizer.encode(query_msg.message)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,

View File

@@ -51,31 +51,3 @@ def combine_message_thread(
total_token_count += message_token_count
return "\n\n".join(message_strs)
def slackify_message(message: ThreadMessage) -> str:
if message.role != MessageType.USER:
return message.message
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
if not messages:
return ""
message_strs: list[str] = []
for message in messages:
if message.role == MessageType.USER:
message_text = (
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
)
elif message.role == MessageType.ASSISTANT:
message_text = f"DanswerBot said in Slack:\n{message.message}"
else:
message_text = (
f"{message.role.value.upper()} said in Slack:\n{message.message}"
)
message_strs.append(message_text)
return "\n\n".join(message_strs)

View File

@@ -1,8 +1,5 @@
import time
import redis
from danswer.db.models import SearchSettings
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -34,44 +31,6 @@ class RedisConnector:
self.tenant_id, self.id, search_settings_id, self.redis
)
def wait_for_indexing_termination(
self,
search_settings_list: list[SearchSettings],
timeout: float = 15.0,
) -> bool:
"""
Returns True if all indexing for the given redis connector is finished within the given timeout.
Returns False if the timeout is exceeded
This check does not guarantee that current indexings being terminated
won't get restarted midflight
"""
finished = False
start = time.monotonic()
while True:
still_indexing = False
for search_settings in search_settings_list:
redis_connector_index = self.new_index(search_settings.id)
if redis_connector_index.fenced:
still_indexing = True
break
if not still_indexing:
finished = True
break
now = time.monotonic()
if now - start > timeout:
break
time.sleep(1)
continue
return finished
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""

View File

@@ -14,9 +14,8 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
class RedisConnectorPermissionSyncPayload(BaseModel):
class RedisConnectorPermissionSyncData(BaseModel):
started: datetime | None
celery_task_id: str | None
class RedisConnectorPermissionSync:
@@ -79,14 +78,14 @@ class RedisConnectorPermissionSync:
return False
@property
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
def payload(self) -> RedisConnectorPermissionSyncData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPermissionSyncPayload.model_validate_json(
payload = RedisConnectorPermissionSyncData.model_validate_json(
cast(str, fence_str)
)
@@ -94,7 +93,7 @@ class RedisConnectorPermissionSync:
def set_fence(
self,
payload: RedisConnectorPermissionSyncPayload | None,
payload: RedisConnectorPermissionSyncData | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
@@ -163,12 +162,6 @@ class RedisConnectorPermissionSync:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"

View File

@@ -1,18 +1,11 @@
from datetime import datetime
from typing import cast
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
class RedisConnectorExternalGroupSyncPayload(BaseModel):
started: datetime | None
celery_task_id: str | None
class RedisConnectorExternalGroupSync:
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
through RedisConnector."""
@@ -75,29 +68,12 @@ class RedisConnectorExternalGroupSync:
return False
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(
self,
payload: RedisConnectorExternalGroupSyncPayload | None,
) -> None:
if not payload:
def set_fence(self, value: bool) -> None:
if not value:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.set(self.fence_key, 0)
@property
def generator_complete(self) -> int | None:

View File

@@ -29,8 +29,6 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
def __init__(
self,
tenant_id: str | None,
@@ -53,7 +51,6 @@ class RedisConnectorIndex:
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -95,18 +92,6 @@ class RedisConnectorIndex:
self.redis.set(self.fence_key, payload.model_dump_json())
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Search"
name: "Knowledge"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml

View File

@@ -6,7 +6,6 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -38,9 +37,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
@@ -161,19 +158,7 @@ def update_cc_pair_status(
status_update_request: CCStatusUpdateRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""This method may wait up to 30 seconds if pausing the connector due to the need to
terminate tasks in progress. Tasks are not guaranteed to terminate within the
timeout.
Returns HTTPStatus.OK if everything finished.
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
did not finish within the timeout.
"""
WAIT_TIMEOUT = 15.0
still_terminating = False
) -> None:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -188,76 +173,10 @@ def update_cc_pair_status(
)
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
redis_connector.stop.set_fence(True)
while True:
logger.debug(
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
"Wait for indexing soft termination timed out. "
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
if not redis_connector_index.fenced:
continue
index_payload = redis_connector_index.payload
if not index_payload:
continue
if not index_payload.celery_task_id:
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
logger.debug(
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
)
still_terminating = True
break
finally:
redis_connector.stop.set_fence(False)
update_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -266,18 +185,6 @@ def update_cc_pair_status(
db_session.commit()
if still_terminating:
return JSONResponse(
status_code=HTTPStatus.ACCEPTED,
content={
"message": "Request accepted, background task termination still in progress"
},
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@@ -360,9 +267,9 @@ def prune_cc_pair(
)
logger.info(
f"Pruning cc_pair: cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(

View File

@@ -17,9 +17,9 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.connectors.google_utils.google_auth import (
@@ -59,7 +59,6 @@ from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import fetch_connectors
from danswer.db.connector import get_connector_credential_ids
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector import update_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
@@ -75,7 +74,6 @@ from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import IndexingMode
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_latest_index_attempts
@@ -88,6 +86,7 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -793,10 +792,12 @@ def connector_run_once(
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
) -> StatusResponse[list[int]]:
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids
@@ -842,41 +843,54 @@ def connector_run_once(
)
]
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]
num_triggers = 0
index_attempt_ids = []
for cc_pair in connector_credential_pairs:
if cc_pair is not None:
indexing_mode = IndexingMode.UPDATE
if run_info.from_beginning:
indexing_mode = IndexingMode.REINDEX
mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
num_triggers += 1
logger.info(
f"connector_run_once - marking cc_pair with indexing trigger: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id} "
f"indexing_trigger={indexing_mode}"
attempt_id = try_creating_indexing_task(
primary_app,
cc_pair,
search_settings,
run_info.from_beginning,
db_session,
r,
tenant_id,
)
if attempt_id:
logger.info(
f"connector_run_once - try_creating_indexing_task succeeded: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id} "
f"attempt={attempt_id} "
)
index_attempt_ids.append(attempt_id)
else:
logger.info(
f"connector_run_once - try_creating_indexing_task failed: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id}"
)
# run the beat task to pick up the triggers immediately
primary_app.send_task(
"check_for_indexing",
priority=DanswerCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
if not index_attempt_ids:
msg = "No new indexing attempts created, indexing jobs are queued or running."
logger.info(msg)
raise HTTPException(
status_code=400,
detail=msg,
)
msg = f"Marked {num_triggers} index attempts with indexing triggers."
msg = f"Successfully created {len(index_attempt_ids)} index attempts. {index_attempt_ids}"
return StatusResponse(
success=True,
message=msg,
data=num_triggers,
data=index_attempt_ids,
)

View File

@@ -45,7 +45,6 @@ class UserPreferences(BaseModel):
visible_assistants: list[int] = []
recent_assistants: list[int] | None = None
default_model: str | None = None
auto_scroll: bool | None = None
class UserInfo(BaseModel):
@@ -80,7 +79,6 @@ class UserInfo(BaseModel):
role=user.role,
preferences=(
UserPreferences(
auto_scroll=user.auto_scroll,
chosen_assistants=user.chosen_assistants,
default_model=user.default_model,
hidden_assistants=user.hidden_assistants,
@@ -130,10 +128,6 @@ class HiddenUpdateRequest(BaseModel):
hidden: bool
class AutoScrollRequest(BaseModel):
auto_scroll: bool | None
class SlackBotCreationRequest(BaseModel):
name: str
enabled: bool
@@ -162,7 +156,6 @@ class SlackChannelConfigCreationRequest(BaseModel):
channel_name: str
respond_tag_only: bool = False
respond_to_bots: bool = False
show_continue_in_web_ui: bool = False
enable_auto_filters: bool = False
# If no team members, assume respond in the channel to everyone
respond_member_group_list: list[str] = Field(default_factory=list)

View File

@@ -80,10 +80,6 @@ def _form_channel_config(
if follow_up_tags is not None:
channel_config["follow_up_tags"] = follow_up_tags
channel_config[
"show_continue_in_web_ui"
] = slack_channel_config_creation_request.show_continue_in_web_ui
channel_config[
"respond_to_bots"
] = slack_channel_config_creation_request.respond_to_bots

View File

@@ -34,6 +34,7 @@ from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.api_key import is_api_key_email_address
@@ -51,7 +52,6 @@ from danswer.db.users import list_users
from danswer.db.users import validate_user_role_update
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import AutoScrollRequest
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
@@ -63,7 +63,6 @@ from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.configs.app_configs import SUPER_USERS
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -498,6 +497,7 @@ def verify_user_logged_in(
return fetch_no_auth_user(store)
raise BasicAuthenticationError(detail="User Not Authenticated")
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
@@ -581,30 +581,6 @@ def update_user_recent_assistants(
db_session.commit()
@router.patch("/auto-scroll")
def update_user_auto_scroll(
request: AutoScrollRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.auto_scroll = request.auto_scroll
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:
raise RuntimeError("This should never happen")
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(auto_scroll=request.auto_scroll)
)
db_session.commit()
@router.patch("/user/default-model")
def update_user_default_model(
request: ChosenDefaultModelRequest,

View File

@@ -27,11 +27,9 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.db.chat import add_chats_to_session_from_slack_thread
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
from danswer.db.chat import duplicate_chat_session_for_user_from_slack
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
@@ -534,38 +532,6 @@ def seed_chat(
)
class SeedChatFromSlackRequest(BaseModel):
chat_session_id: UUID
class SeedChatFromSlackResponse(BaseModel):
redirect_url: str
@router.post("/seed-chat-session-from-slack")
def seed_chat_from_slack(
chat_seed_request: SeedChatFromSlackRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SeedChatFromSlackResponse:
slack_chat_session_id = chat_seed_request.chat_session_id
new_chat_session = duplicate_chat_session_for_user_from_slack(
db_session=db_session,
user=user,
chat_session_id=slack_chat_session_id,
)
add_chats_to_session_from_slack_thread(
db_session=db_session,
slack_chat_session_id=slack_chat_session_id,
new_chat_session_id=new_chat_session.id,
)
return SeedChatFromSlackResponse(
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
)
"""File upload"""
@@ -707,18 +673,14 @@ def upload_files_for_chat(
}
@router.get("/file/{file_id:path}")
@router.get("/file/{file_id}")
def fetch_chat_file(
file_id: str,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> Response:
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(file_id)
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")
return StreamingResponse(file_io, media_type=media_type)
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")

View File

@@ -79,7 +79,6 @@ class CreateChatMessageRequest(ChunkContext):
message: str
# Files that we should attach to this message
file_descriptors: list[FileDescriptor]
# If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred
# Use prompt_id 0 to use the system default prompt which is Answer-Question

View File

@@ -2,6 +2,7 @@ from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
@@ -37,6 +38,10 @@ basic_router = APIRouter(prefix="/settings")
def put_settings(
settings: Settings, _: User | None = Depends(current_admin_user)
) -> None:
try:
settings.check_validity()
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
store_settings(settings)

View File

@@ -41,10 +41,33 @@ class Notification(BaseModel):
class Settings(BaseModel):
"""General settings"""
chat_page_enabled: bool = True
search_page_enabled: bool = True
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE
def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled
search_page_enabled = self.search_page_enabled
default_page = self.default_page
if chat_page_enabled is False and search_page_enabled is False:
raise ValueError(
"One of `search_page_enabled` and `chat_page_enabled` must be True."
)
if default_page == PageType.CHAT and chat_page_enabled is False:
raise ValueError(
"The default page cannot be 'chat' if the chat page is disabled."
)
if default_page == PageType.SEARCH and search_page_enabled is False:
raise ValueError(
"The default page cannot be 'search' if the search page is disabled."
)
class UserSettings(Settings):
notifications: list[Notification]

View File

@@ -1,72 +1,23 @@
from functools import lru_cache
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import JWT_PUBLIC_KEY_URL
from ee.danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from ee.danswer.configs.app_configs import SUPER_USERS
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
logger = setup_logger()
@lru_cache()
def get_public_key() -> str | None:
if JWT_PUBLIC_KEY_URL is None:
logger.error("JWT_PUBLIC_KEY_URL is not set")
return None
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
return response.text
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
try:
public_key_pem = get_public_key()
if public_key_pem is None:
logger.error("Failed to retrieve public key")
return None
payload = jwt_decode(
token,
public_key_pem,
algorithms=["RS256"],
audience=None,
)
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
except InvalidTokenError:
logger.error("Invalid JWT token")
get_public_key.cache_clear()
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
get_public_key.cache_clear()
return None
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -87,13 +38,6 @@ async def optional_user_(
)
user = saml_account.user if saml_account else None
# If user is still None, check for JWT in Authorization header
if user is None and JWT_PUBLIC_KEY_URL is not None:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header[len("Bearer ") :].strip()
user = await verify_jwt_token(token, async_db_session)
return user

View File

@@ -1,4 +1,3 @@
import json
import os
# Applicable for OIDC Auth
@@ -20,11 +19,3 @@ STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

View File

@@ -11,7 +11,6 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential__UserGroup
@@ -299,11 +298,6 @@ def fetch_user_groups_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[str]]]:
"""
Fetches all user groups that have access to the given documents.
NOTE: this doesn't include groups if the cc_pair is access type SYNC
"""
stmt = (
select(Document.id, func.array_agg(UserGroup.name))
.join(
@@ -312,11 +306,7 @@ def fetch_user_groups_for_documents(
)
.join(
ConnectorCredentialPair,
and_(
ConnectorCredentialPair.id
== UserGroup__ConnectorCredentialPair.cc_pair_id,
ConnectorCredentialPair.access_type != AccessType.SYNC,
),
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
)
.join(
DocumentByConnectorCredentialPair,

View File

@@ -16,7 +16,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
_REQUEST_PAGINATION_LIMIT = 5000
_REQUEST_PAGINATION_LIMIT = 100
def _get_server_space_permissions(
@@ -97,7 +97,6 @@ def _get_space_permissions(
confluence_client: OnyxConfluence,
is_cloud: bool,
) -> dict[str, ExternalAccess]:
logger.debug("Getting space permissions")
# Gets all the spaces in the Confluence instance
all_space_keys = []
start = 0
@@ -114,7 +113,6 @@ def _get_space_permissions(
start += len(spaces_batch.get("results", []))
# Gets the permissions for each space
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
for space_key in all_space_keys:
if is_cloud:
@@ -244,7 +242,6 @@ def _fetch_all_page_restrictions_for_space(
logger.warning(f"No permissions found for document {slim_doc.id}")
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
@@ -257,28 +254,27 @@ def confluence_doc_sync(
it in postgres so that when it gets created later, the permissions are
already populated
"""
logger.debug("Starting confluence doc sync")
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
if confluence_connector.confluence_client is None:
raise ValueError("Failed to load credentials")
confluence_client = confluence_connector.confluence_client
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
space_permissions_by_space_key = _get_space_permissions(
confluence_client=confluence_connector.confluence_client,
confluence_client=confluence_client,
is_cloud=is_cloud,
)
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions_for_space(
confluence_client=confluence_connector.confluence_client,
confluence_client=confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
)

View File

@@ -14,10 +14,7 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
continue
user = user_result["user"]
email = user.get("email")
if not email:
# This field is only present in Confluence Server

View File

@@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 5 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: 5 * 60,
# Polling is not supported so we fetch all group permissions every 60 seconds
DocumentSource.GOOGLE_DRIVE: 60,
DocumentSource.CONFLUENCE: 60,
}

View File

@@ -13,6 +13,7 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.main import get_application as get_application_base
from danswer.main import include_router_with_global_prefix_prepended
from danswer.server.api_key.api import router as api_key_router
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
@@ -115,6 +116,8 @@ def get_application() -> FastAPI:
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, query_history_router)
# Api key management
include_router_with_global_prefix_prepended(application, api_key_router)
# EE only backend APIs
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)

View File

@@ -113,6 +113,10 @@ async def refresh_access_token(
def put_settings(
settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
) -> None:
try:
settings.check_validity()
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
store_settings(settings)

View File

@@ -157,6 +157,7 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
def _seed_settings(settings: Settings) -> None:
logger.notice("Seeding Settings")
try:
settings.check_validity()
store_base_settings(settings)
logger.notice("Successfully seeded Settings")
except ValueError as e:

View File

@@ -11,7 +11,6 @@ from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
@@ -206,22 +205,28 @@ class CloudEmbedding:
model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
@staticmethod
def create(
@@ -425,11 +430,6 @@ async def process_embed_request(
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:
raise HTTPException(
status_code=429,
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)

View File

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.53.1
litellm==1.50.2
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.55.3
openai==1.52.2
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5

View File

@@ -163,92 +163,47 @@ SUPPORTED_EMBEDDING_MODELS = [
dim=1024,
index_name="danswer_chunk_cohere_embed_english_v3_0",
),
SupportedEmbeddingModel(
name="cohere/embed-english-v3.0",
dim=1024,
index_name="danswer_chunk_embed_english_v3_0",
),
SupportedEmbeddingModel(
name="cohere/embed-english-light-v3.0",
dim=384,
index_name="danswer_chunk_cohere_embed_english_light_v3_0",
),
SupportedEmbeddingModel(
name="cohere/embed-english-light-v3.0",
dim=384,
index_name="danswer_chunk_embed_english_light_v3_0",
),
SupportedEmbeddingModel(
name="openai/text-embedding-3-large",
dim=3072,
index_name="danswer_chunk_openai_text_embedding_3_large",
),
SupportedEmbeddingModel(
name="openai/text-embedding-3-large",
dim=3072,
index_name="danswer_chunk_text_embedding_3_large",
),
SupportedEmbeddingModel(
name="openai/text-embedding-3-small",
dim=1536,
index_name="danswer_chunk_openai_text_embedding_3_small",
),
SupportedEmbeddingModel(
name="openai/text-embedding-3-small",
dim=1536,
index_name="danswer_chunk_text_embedding_3_small",
),
SupportedEmbeddingModel(
name="google/text-embedding-004",
dim=768,
index_name="danswer_chunk_google_text_embedding_004",
),
SupportedEmbeddingModel(
name="google/text-embedding-004",
dim=768,
index_name="danswer_chunk_text_embedding_004",
),
SupportedEmbeddingModel(
name="google/textembedding-gecko@003",
dim=768,
index_name="danswer_chunk_google_textembedding_gecko_003",
),
SupportedEmbeddingModel(
name="google/textembedding-gecko@003",
dim=768,
index_name="danswer_chunk_textembedding_gecko_003",
),
SupportedEmbeddingModel(
name="voyage/voyage-large-2-instruct",
dim=1024,
index_name="danswer_chunk_voyage_large_2_instruct",
),
SupportedEmbeddingModel(
name="voyage/voyage-large-2-instruct",
dim=1024,
index_name="danswer_chunk_large_2_instruct",
),
SupportedEmbeddingModel(
name="voyage/voyage-light-2-instruct",
dim=384,
index_name="danswer_chunk_voyage_light_2_instruct",
),
SupportedEmbeddingModel(
name="voyage/voyage-light-2-instruct",
dim=384,
index_name="danswer_chunk_light_2_instruct",
),
# Self-hosted models
SupportedEmbeddingModel(
name="nomic-ai/nomic-embed-text-v1",
dim=768,
index_name="danswer_chunk_nomic_ai_nomic_embed_text_v1",
),
SupportedEmbeddingModel(
name="nomic-ai/nomic-embed-text-v1",
dim=768,
index_name="danswer_chunk_nomic_embed_text_v1",
),
SupportedEmbeddingModel(
name="intfloat/e5-base-v2",
dim=768,

View File

@@ -1,88 +0,0 @@
import json
import os
import time
from pathlib import Path
import pytest
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.slab.connector import SlabConnector
def load_test_data(file_name: str = "test_slab_data.json") -> dict[str, str]:
current_dir = Path(__file__).parent
with open(current_dir / file_name, "r") as f:
return json.load(f)
@pytest.fixture
def slab_connector() -> SlabConnector:
connector = SlabConnector(
base_url="https://onyx-test.slab.com/",
)
connector.load_credentials(
{
"slab_bot_token": os.environ["SLAB_BOT_TOKEN"],
}
)
return connector
@pytest.mark.xfail(
reason=(
"Need a test account with a slab subscription to run this test."
"Trial only lasts 14 days."
)
)
def test_slab_connector_basic(slab_connector: SlabConnector) -> None:
all_docs: list[Document] = []
target_test_doc_id = "jcp6cohu"
target_test_doc: Document | None = None
for doc_batch in slab_connector.poll_source(0, time.time()):
for doc in doc_batch:
all_docs.append(doc)
if doc.id == target_test_doc_id:
target_test_doc = doc
assert len(all_docs) == 6
assert target_test_doc is not None
desired_test_data = load_test_data()
assert (
target_test_doc.semantic_identifier == desired_test_data["semantic_identifier"]
)
assert target_test_doc.source == DocumentSource.SLAB
assert target_test_doc.metadata == {}
assert target_test_doc.primary_owners is None
assert target_test_doc.secondary_owners is None
assert target_test_doc.title is None
assert target_test_doc.from_ingestion_api is False
assert target_test_doc.additional_info is None
assert len(target_test_doc.sections) == 1
section = target_test_doc.sections[0]
# Need to replace the weird apostrophe with a normal one
assert section.text.replace("\u2019", "'") == desired_test_data["section_text"]
assert section.link == desired_test_data["link"]
@pytest.mark.xfail(
reason=(
"Need a test account with a slab subscription to run this test."
"Trial only lasts 14 days."
)
)
def test_slab_connector_slim(slab_connector: SlabConnector) -> None:
# Get all doc IDs from the full connector
all_full_doc_ids = set()
for doc_batch in slab_connector.load_from_state():
all_full_doc_ids.update([doc.id for doc in doc_batch])
# Get all doc IDs from the slim connector
all_slim_doc_ids = set()
for slim_doc_batch in slab_connector.retrieve_all_slim_documents():
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
# The set of full doc IDs should be always be a subset of the slim doc IDs
assert all_full_doc_ids.issubset(all_slim_doc_ids)

View File

@@ -1,5 +0,0 @@
{
"section_text": "Learn about Posts\nWelcome\nThis is a post, where you can edit, share, and collaborate in real time with your team. We'd love to show you how it works!\nReading and editing\nClick the mode button to toggle between read and edit modes. You can only make changes to a post when editing.\nOrganize your posts\nWhen in edit mode, you can add topics to a post, which will keep it organized for the right 👀 to see.\nSmart mentions\nMentions are references to users, posts, topics and third party tools that show details on hover. Paste in a link for automatic conversion.\nLook back in time\nYou are ready to begin writing. You can always bring back this tour in the help menu.\nGreat job!\nYou are ready to begin writing. You can always bring back this tour in the help menu.\n\n",
"link": "https://onyx-test.slab.com/posts/learn-about-posts-jcp6cohu",
"semantic_identifier": "Learn about Posts"
}

View File

@@ -7,7 +7,6 @@ from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbeddingProvider
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
VALID_LONG_SAMPLE = ["hi " * 999]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
# seem to be true
TOO_LONG_SAMPLE = ["a"] * 2500
@@ -100,42 +99,3 @@ def local_nomic_embedding_model() -> EmbeddingModel:
def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
_run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
_run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
@pytest.fixture
def azure_embedding_model() -> EmbeddingModel:
return EmbeddingModel(
server_host="localhost",
server_port=9000,
model_name="text-embedding-3-large",
normalize=True,
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("AZURE_API_KEY"),
provider_type=EmbeddingProvider.AZURE,
api_url=os.getenv("AZURE_API_URL"),
)
# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
# """NOTE: this test relies on a very low rate limit for the Azure API +
# this test only being run once in a 1 minute window"""
# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
# # limits assuming the limit is 1000 tokens per minute
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# assert len(result) == 1
# assert len(result[0]) == 1536
# # this should fail
# with pytest.raises(ModelServerRateLimitError):
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# # this should succeed, since passage requests retry up to 10 times
# start = time.time()
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
# assert len(result) == 1
# assert len(result[0]) == 1536
# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits

View File

@@ -240,85 +240,7 @@ class CCPairManager:
result.raise_for_status()
@staticmethod
def wait_for_indexing_inactive(
cc_pair: DATestCCPair,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
terminating that indexing."""
print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}")
start = time.monotonic()
while True:
fetched_cc_pairs = CCPairManager.get_indexing_statuses(
user_performing_action
)
for fetched_cc_pair in fetched_cc_pairs:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
if fetched_cc_pair.in_progress:
continue
print(f"Indexing is inactive: cc_pair={cc_pair.id}")
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s"
)
print(
f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
)
time.sleep(5)
@staticmethod
def wait_for_indexing_in_progress(
cc_pair: DATestCCPair,
timeout: float = MAX_DELAY,
num_docs: int = 16,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
terminating that indexing."""
start = time.monotonic()
while True:
fetched_cc_pairs = CCPairManager.get_indexing_statuses(
user_performing_action
)
for fetched_cc_pair in fetched_cc_pairs:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
if not fetched_cc_pair.in_progress:
continue
if fetched_cc_pair.docs_indexed >= num_docs:
print(
"Indexed at least the requested number of docs: "
f"cc_pair={cc_pair.id} "
f"docs_indexed={fetched_cc_pair.docs_indexed} "
f"num_docs={num_docs}"
)
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
)
print(
f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
)
time.sleep(5)
@staticmethod
def wait_for_indexing_completion(
def wait_for_indexing(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,

View File

@@ -1,62 +0,0 @@
import mimetypes
from typing import cast
from typing import IO
from typing import List
from typing import Tuple
import requests
from danswer.file_store.models import FileDescriptor
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
class FileManager:
@staticmethod
def upload_files(
files: List[Tuple[str, IO]],
user_performing_action: DATestUser | None = None,
) -> Tuple[List[FileDescriptor], str]:
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
files_param = []
for filename, file_obj in files:
mime_type, _ = mimetypes.guess_type(filename)
if mime_type is None:
mime_type = "application/octet-stream"
files_param.append(("files", (filename, file_obj, mime_type)))
response = requests.post(
f"{API_SERVER_URL}/chat/file",
files=files_param,
headers=headers,
)
if not response.ok:
return (
cast(List[FileDescriptor], []),
f"Failed to upload files - {response.json().get('detail', 'Unknown error')}",
)
response_json = response.json()
return response_json.get("files", cast(List[FileDescriptor], [])), ""
@staticmethod
def fetch_uploaded_file(
file_id: str,
user_performing_action: DATestUser | None = None,
) -> bytes:
response = requests.get(
f"{API_SERVER_URL}/chat/file/{file_id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return response.content

View File

@@ -14,7 +14,6 @@ from tests.integration.common_utils.managers.document_search import (
)
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
@@ -78,7 +77,7 @@ def test_slack_permission_sync(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -113,7 +112,7 @@ def test_slack_permission_sync(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -216,124 +215,3 @@ def test_slack_permission_sync(
# Ensure test_user_1 can only see messages from the public channel
assert public_message in danswer_doc_message_strings
assert private_message not in danswer_doc_message_strings
def test_slack_group_permission_sync(
reset: None,
vespa_client: vespa_fixture,
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
"""
This test ensures that permission sync overrides danswer group access.
"""
public_channel, private_channel = slack_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@onyx-test.com",
)
# Create a user group and adding the non-admin user to it
user_group = UserGroupManager.create(
name="test_group",
user_ids=[test_user_1.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group],
user_performing_action=admin_user,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
LLMProviderManager.create(user_performing_action=admin_user)
# Add only admin to the private channel
SlackManager.set_channel_members(
slack_client=slack_client,
admin_user_id=admin_user_id,
channel=private_channel,
user_ids=[admin_user_id],
)
before = datetime.now(timezone.utc)
credential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
},
user_performing_action=admin_user,
)
# Create connector with sync access and assign it to the user group
connector = ConnectorManager.create(
name="Slack",
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [private_channel["name"]],
},
access_type=AccessType.SYNC,
groups=[user_group.id],
user_performing_action=admin_user,
)
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC,
user_performing_action=admin_user,
groups=[user_group.id],
)
# Add a test message to the private channel
private_message = "This is a secret message: 987654"
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
message=private_message,
)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Run permission sync
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
)
# Verify admin can see the message
admin_docs = DocumentSearchManager.search_documents(
query="secret message",
user_performing_action=admin_user,
)
assert private_message in admin_docs
# Verify test_user_1 cannot see the message despite being in the group
# (Slack permissions should take precedence)
user_1_docs = DocumentSearchManager.search_documents(
query="secret message",
user_performing_action=test_user_1,
)
assert private_message not in user_1_docs

View File

@@ -74,7 +74,7 @@ def test_slack_prune(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_prune(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,

View File

@@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
)
@@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair_2, now, timeout=120, user_performing_action=admin_user
)
@@ -82,48 +82,3 @@ def test_overlapping_connector_creation(reset: None) -> None:
assert info_2
assert info_1.num_docs_indexed == info_2.num_docs_indexed
def test_connector_pause_while_indexing(reset: None) -> None:
"""Tests that we can pause a connector while indexing is in progress and that
tasks end early or abort as a result.
TODO: This does not specifically test for soft or hard termination code paths.
Design specific tests for those use cases.
"""
admin_user: DATestUser = UserManager.create(name="admin_user")
config = {
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
"space": "",
"is_cloud": True,
"page_id": "",
}
credential = {
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
}
# store the time before we create the connector so that we know after
# when the indexing should have started
datetime.now(timezone.utc)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.CONFLUENCE,
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_in_progress(
cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user
)
CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user)
CCPairManager.wait_for_indexing_inactive(
cc_pair_1, timeout=60, user_performing_action=admin_user
)
return

View File

@@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
CCPairManager.wait_for_indexing(
cc_pair_1, now, timeout=60, user_performing_action=admin_user
)

View File

@@ -385,16 +385,6 @@ def process_text(
"Here is some text[[1]](https://0.com). Some other text",
["doc_0"],
),
# ['To', ' set', ' up', ' D', 'answer', ',', ' if', ' you', ' are', ' running', ' it', ' yourself', ' and',
# ' need', ' access', ' to', ' certain', ' features', ' like', ' auto', '-sync', 'ing', ' document',
# '-level', ' access', ' permissions', ',', ' you', ' should', ' reach', ' out', ' to', ' the', ' D',
# 'answer', ' team', ' to', ' receive', ' access', ' [[', '4', ']].', '']
(
"Unique tokens with double brackets and a single token that ends the citation and has characters after it.",
["... to receive access", " [[", "1", "]].", ""],
"... to receive access [[1]](https://0.com).",
["doc_0"],
),
],
)
def test_citation_extraction(

View File

@@ -130,7 +130,6 @@ services:
restart: always
environment:
- ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
- JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-} # used for JWT authentication of users via API
# Gen AI Settings (Needed by DanswerBot)
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}

6
node_modules/.package-lock.json generated vendored
View File

@@ -1,6 +0,0 @@
{
"name": "danswer",
"lockfileVersion": 3,
"requires": true,
"packages": {}
}

View File

@@ -1,9 +0,0 @@
declare module "favicon-fetch" {
interface FaviconFetchOptions {
uri: string;
}
function faviconFetch(options: FaviconFetchOptions): string | null;
export default faviconFetch;
}

1011
web/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -17,13 +17,11 @@
"@headlessui/react": "^2.2.0",
"@headlessui/tailwindcss": "^0.2.1",
"@phosphor-icons/react": "^2.0.8",
"@radix-ui/react-checkbox": "^1.1.2",
"@radix-ui/react-dialog": "^1.1.2",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.1.2",
"@radix-ui/react-select": "^2.1.2",
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.1",
"@radix-ui/react-tooltip": "^1.1.3",
"@sentry/nextjs": "^8.34.0",
@@ -39,7 +37,6 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"date-fns": "^3.6.0",
"favicon-fetch": "^1.0.0",
"formik": "^2.2.9",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
@@ -70,7 +67,6 @@
"tailwindcss-animate": "^1.0.7",
"typescript": "5.0.3",
"uuid": "^9.0.1",
"vaul": "^1.1.1",
"yup": "^1.4.0"
},
"devDependencies": {

View File

@@ -24,6 +24,13 @@ import {
TextFormField,
} from "@/components/admin/connectors/Field";
import {
Card,
CardHeader,
CardTitle,
CardContent,
CardFooter,
} from "@/components/ui/card";
import { usePopup } from "@/components/admin/connectors/Popup";
import { getDisplayNameForModel, useCategories } from "@/lib/hooks";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
@@ -398,7 +405,7 @@ export function AssistantEditor({
message: `"${assistant.name}" has been added to your list.`,
type: "success",
});
await refreshAssistants();
router.refresh();
} else {
setPopup({
message: `"${assistant.name}" could not be added to your list.`,

View File

@@ -90,7 +90,7 @@ export function PersonasTable() {
message: `Failed to update persona order - ${await response.text()}`,
});
setFinalPersonas(assistants);
await refreshAssistants();
router.refresh();
return;
}
@@ -151,7 +151,7 @@ export function PersonasTable() {
persona.is_visible
);
if (response.ok) {
await refreshAssistants();
router.refresh();
} else {
setPopup({
type: "error",
@@ -183,7 +183,7 @@ export function PersonasTable() {
onClick={async () => {
const response = await deletePersona(persona.id);
if (response.ok) {
await refreshAssistants();
router.refresh();
} else {
alert(
`Failed to delete persona - ${await response.text()}`

View File

@@ -259,8 +259,29 @@ export async function updatePersona(
): Promise<[Response, Response | null]> {
const { id, existingPromptId } = personaUpdateRequest;
let fileId = null;
if (personaUpdateRequest.uploaded_image) {
fileId = await uploadFile(personaUpdateRequest.uploaded_image);
if (!fileId) {
return [new Response(null, { status: 400 }), null];
}
}
const updatePersonaResponse = await fetch(`/api/persona/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaAPIBody(personaUpdateRequest, existingPromptId ?? 0, fileId)
),
});
if (!updatePersonaResponse.ok) {
return [updatePersonaResponse, null];
}
let promptResponse;
let promptId: number | null = null;
if (existingPromptId !== undefined) {
promptResponse = await updatePrompt({
promptId: existingPromptId,
@@ -269,7 +290,6 @@ export async function updatePersona(
taskPrompt: personaUpdateRequest.task_prompt,
includeCitations: personaUpdateRequest.include_citations,
});
promptId = existingPromptId;
} else {
promptResponse = await createPrompt({
personaName: personaUpdateRequest.name,
@@ -277,30 +297,7 @@ export async function updatePersona(
taskPrompt: personaUpdateRequest.task_prompt,
includeCitations: personaUpdateRequest.include_citations,
});
promptId = promptResponse.ok
? ((await promptResponse.json()).id as number)
: null;
}
let fileId = null;
if (personaUpdateRequest.uploaded_image) {
fileId = await uploadFile(personaUpdateRequest.uploaded_image);
if (!fileId) {
return [promptResponse, null];
}
}
const updatePersonaResponse =
promptResponse.ok && promptId !== null
? await fetch(`/api/persona/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaAPIBody(personaUpdateRequest, promptId, fileId)
),
})
: null;
return [promptResponse, updatePersonaResponse];
}

View File

@@ -60,24 +60,21 @@ export function SlackChannelConfigsTable({
.slice(numToDisplay * (page - 1), numToDisplay * page)
.map((slackChannelConfig) => {
return (
<TableRow
key={slackChannelConfig.id}
className="cursor-pointer hover:bg-gray-100 transition-colors"
onClick={() => {
window.location.href = `/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`;
}}
>
<TableRow key={slackChannelConfig.id}>
<TableCell>
<div className="flex gap-x-2">
<div className="my-auto">
<Link
className="cursor-pointer my-auto"
href={`/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`}
>
<EditIcon />
</div>
</Link>
<div className="my-auto">
{"#" + slackChannelConfig.channel_config.channel_name}
</div>
</div>
</TableCell>
<TableCell onClick={(e) => e.stopPropagation()}>
<TableCell>
{slackChannelConfig.persona &&
!isPersonaASlackBotPersona(slackChannelConfig.persona) ? (
<Link
@@ -101,11 +98,10 @@ export function SlackChannelConfigsTable({
: "-"}
</div>
</TableCell>
<TableCell onClick={(e) => e.stopPropagation()}>
<TableCell>
<div
className="cursor-pointer hover:text-destructive"
onClick={async (e) => {
e.stopPropagation();
onClick={async () => {
const response = await deleteSlackChannelConfig(
slackChannelConfig.id
);

View File

@@ -81,11 +81,6 @@ export const SlackChannelConfigCreationForm = ({
respond_to_bots:
existingSlackChannelConfig?.channel_config?.respond_to_bots ||
false,
show_continue_in_web_ui:
// If we're updating, we want to keep the existing value
// Otherwise, we want to default to true
existingSlackChannelConfig?.channel_config
?.show_continue_in_web_ui ?? !isUpdate,
enable_auto_filters:
existingSlackChannelConfig?.enable_auto_filters || false,
respond_member_group_list:
@@ -124,7 +119,6 @@ export const SlackChannelConfigCreationForm = ({
questionmark_prefilter_enabled: Yup.boolean().required(),
respond_tag_only: Yup.boolean().required(),
respond_to_bots: Yup.boolean().required(),
show_continue_in_web_ui: Yup.boolean().required(),
enable_auto_filters: Yup.boolean().required(),
respond_member_group_list: Yup.array().of(Yup.string()).required(),
still_need_help_enabled: Yup.boolean().required(),
@@ -276,13 +270,7 @@ export const SlackChannelConfigCreationForm = ({
{showAdvancedOptions && (
<div className="mt-4">
<BooleanFormField
name="show_continue_in_web_ui"
removeIndent
label="Show Continue in Web UI button"
tooltip="If set, will show a button at the bottom of the response that allows the user to continue the conversation in the Danswer Web UI"
/>
<div className="w-64 mb-4 mt-4">
<div className="w-64 mb-4">
<SelectorFormField
name="response_type"
label="Answer Type"

View File

@@ -15,7 +15,6 @@ interface SlackChannelConfigCreationRequest {
questionmark_prefilter_enabled: boolean;
respond_tag_only: boolean;
respond_to_bots: boolean;
show_continue_in_web_ui: boolean;
respond_member_group_list: string[];
follow_up_tags?: string[];
usePersona: boolean;
@@ -44,7 +43,6 @@ const buildRequestBodyFromCreationRequest = (
channel_name: creationRequest.channel_name,
respond_tag_only: creationRequest.respond_tag_only,
respond_to_bots: creationRequest.respond_to_bots,
show_continue_in_web_ui: creationRequest.show_continue_in_web_ui,
enable_auto_filters: creationRequest.enable_auto_filters,
respond_member_group_list: creationRequest.respond_member_group_list,
answer_filters: buildFiltersFromCreationRequest(creationRequest),

View File

@@ -22,6 +22,7 @@ function SlackBotEditPage({
const unwrappedParams = use(params);
const { popup, setPopup } = usePopup();
console.log("unwrappedParams", unwrappedParams);
const {
data: slackBot,
isLoading: isSlackBotLoading,

View File

@@ -161,7 +161,7 @@ export default function UpgradingPage({
reindexingProgress={sortedReindexingProgress}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch reindexing progress" />
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
)}
</>
) : (
@@ -171,7 +171,7 @@ export default function UpgradingPage({
</h3>
<p className="mb-4 text-text-800">
You&apos;re currently switching embedding models, but there
are no connectors to reindex. This means the transition will
are no connectors to re-index. This means the transition will
be quick and seamless!
</p>
<p className="text-text-600">

View File

@@ -6,8 +6,6 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { setCCPairStatus } from "@/lib/ccPair";
import { useState } from "react";
import { LoadingAnimation } from "@/components/Loading";
export function ModifyStatusButtonCluster({
ccPair,
@@ -15,72 +13,44 @@ export function ModifyStatusButtonCluster({
ccPair: CCPairFullInfo;
}) {
const { popup, setPopup } = usePopup();
const [isUpdating, setIsUpdating] = useState(false);
const handleStatusChange = async (
newStatus: ConnectorCredentialPairStatus
) => {
if (isUpdating) return; // Prevent double-clicks or multiple requests
setIsUpdating(true);
try {
// Call the backend to update the status
await setCCPairStatus(ccPair.id, newStatus, setPopup);
// Use mutate to revalidate the status on the backend
await mutate(buildCCPairInfoUrl(ccPair.id));
} catch (error) {
console.error("Failed to update status", error);
} finally {
// Reset local updating state and button text after mutation
setIsUpdating(false);
}
};
// Compute the button text based on current state and backend status
const buttonText =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Re-Enable"
: "Pause";
const tooltip =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Click to start indexing again!"
: "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";
return (
<>
{popup}
<Button
className="flex items-center justify-center w-auto min-w-[100px] px-4 py-2"
variant={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "success-reverse"
: "default"
}
disabled={isUpdating}
onClick={() =>
handleStatusChange(
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? ConnectorCredentialPairStatus.ACTIVE
: ConnectorCredentialPairStatus.PAUSED
)
}
tooltip={tooltip}
>
{isUpdating ? (
<LoadingAnimation
text={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Resuming"
: "Pausing"
}
size="text-md"
/>
) : (
buttonText
)}
</Button>
{ccPair.status === ConnectorCredentialPairStatus.PAUSED ? (
<Button
variant="success-reverse"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.ACTIVE,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip="Click to start indexing again!"
>
Re-Enable
</Button>
) : (
<Button
variant="default"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.PAUSED,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip={
"When paused, the connectors documents will still" +
" be visible. However, no new documents will be indexed."
}
>
Pause
</Button>
)}
</>
);
}

View File

@@ -121,7 +121,7 @@ export function ReIndexButton({
{popup}
<Button
variant="success-reverse"
className="ml-auto min-w-[100px]"
className="ml-auto"
onClick={() => {
setReIndexPopupVisible(true);
}}

View File

@@ -25,7 +25,6 @@ import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not

View File

@@ -83,7 +83,7 @@ const EditRow = ({
</div>
</TooltipTrigger>
{!documentSet.is_up_to_date && (
<TooltipContent width="max-w-sm">
<TooltipContent maxWidth="max-w-sm">
<div className="flex break-words break-keep whitespace-pre-wrap items-start">
<InfoIcon className="mr-2 mt-0.5" />
Cannot update while syncing! Wait for the sync to finish, then

Some files were not shown because too many files have changed in this diff Show More