Compare commits


63 Commits

Author SHA1 Message Date
pablodanswer
d97e96b3f0 build org 2024-12-01 17:20:43 -08:00
pablodanswer
911fbfa5a6 k 2024-12-01 17:14:09 -08:00
pablodanswer
d02305671a k 2024-12-01 17:12:54 -08:00
pablodanswer
bdfa29dcb5 slack chat 2024-12-01 15:06:32 -08:00
pablodanswer
897ed03c19 fix memoization 2024-12-01 15:02:46 -08:00
pablodanswer
49f0c4f1f8 fix memoization 2024-12-01 15:02:28 -08:00
pablodanswer
338c02171b rm shs 2024-12-01 12:46:07 -08:00
pablodanswer
ef1ade84b6 k 2024-12-01 12:46:07 -08:00
pablodanswer
7c81566c54 k 2024-12-01 12:46:07 -08:00
pablodanswer
c9df0aea47 k 2024-12-01 12:46:07 -08:00
pablodanswer
92e0aeecba k 2024-12-01 12:46:07 -08:00
pablodanswer
30c7e07783 update for all screen sizes 2024-12-01 12:46:07 -08:00
pablodanswer
e99704e9bd update sidebar line 2024-12-01 12:46:07 -08:00
pablodanswer
7f36387f7f k 2024-12-01 12:46:07 -08:00
pablodanswer
407592445b minor nit 2024-12-01 12:46:07 -08:00
pablodanswer
2e533d8188 minor date range clarity 2024-12-01 12:46:07 -08:00
pablodanswer
5b56869937 quick unification of icons 2024-12-01 12:46:07 -08:00
pablodanswer
7baeab54e2 address comments 2024-12-01 12:46:07 -08:00
pablodanswer
aefcfb75ef k 2024-12-01 12:46:07 -08:00
pablodanswer
e5adcb457d k 2024-12-01 12:46:07 -08:00
pablodanswer
db6463644a small nit 2024-12-01 12:46:07 -08:00
pablodanswer
e26ba70cc6 update filters 2024-12-01 12:46:07 -08:00
pablodanswer
66ff723c94 badge up 2024-12-01 12:46:07 -08:00
pablodanswer
dda66f2178 finalize changes 2024-12-01 12:46:07 -08:00
pablodanswer
0a27f72d20 cleanup complete 2024-12-01 12:46:07 -08:00
pablodanswer
fe397601ed minor cleanup 2024-12-01 12:46:07 -08:00
pablodanswer
3bc187c1d1 clean up unused components 2024-12-01 12:46:07 -08:00
pablodanswer
9a0b9eecf0 source types update 2024-12-01 12:46:07 -08:00
pablodanswer
e08db414c0 viewport height update 2024-12-01 12:46:07 -08:00
pablodanswer
b5734057b7 various updates 2024-12-01 12:46:07 -08:00
pablodanswer
56beb3ec82 k 2024-12-01 12:46:07 -08:00
pablodanswer
9f2c8118d7 updates 2024-12-01 12:46:07 -08:00
pablodanswer
6e4a3d5d57 finalize tags 2024-12-01 12:46:07 -08:00
pablodanswer
5b3dcf718f scroll nit 2024-12-01 12:46:07 -08:00
pablodanswer
07bd20b5b9 push fade 2024-12-01 12:46:07 -08:00
pablodanswer
eb01b175ae update logs 2024-12-01 12:46:06 -08:00
pablodanswer
6f55e5fe56 default 2024-12-01 12:46:06 -08:00
pablodanswer
18e7609bfc update scroll 2024-12-01 12:46:06 -08:00
pablodanswer
dd69ec6cdb cleanup 2024-12-01 12:46:06 -08:00
pablodanswer
e961fa2820 fix mystery reorg 2024-12-01 12:46:06 -08:00
pablodanswer
d41bf9a3ff clean up 2024-12-01 12:46:06 -08:00
pablodanswer
e3a6c76d51 k 2024-12-01 12:46:06 -08:00
pablodanswer
719c2aa0df update 2024-12-01 12:46:06 -08:00
pablodanswer
09f487e402 updates 2024-12-01 12:46:06 -08:00
pablodanswer
33a1548fc1 k 2024-12-01 12:46:06 -08:00
pablodanswer
e87c93226a updated chat flow 2024-12-01 12:46:06 -08:00
pablodanswer
5e11a79593 proper no assistant typing + no assistant modal 2024-12-01 12:46:06 -08:00
Chris Weaver
f12eb4a5cf Fix assistant prompt zero-ing (#3293) 2024-11-30 04:45:40 +00:00
Chris Weaver
16863de0aa Improve model token limit detection (#3292)
* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Upgrade OpenAI as well

* Fix mypy
2024-11-30 04:42:56 +00:00
Weves
63d1eefee5 Add read_only=True for xlsx parsing 2024-11-28 16:02:02 -08:00
pablodanswer
e338677896 order seeding 2024-11-28 15:41:10 -08:00
hagen-danswer
7be80c4af9 increased the pagination limit for confluence spaces (#3288) 2024-11-28 19:04:38 +00:00
rkuo-danswer
7f1e4a02bf Feature/kill indexing (#3213)
* checkpoint

* add celery termination of the task

* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings

* rename payload

* pretty sure these weren't named correctly

* testing in progress

* cleanup

* remove space

* merge fix

* three dots animation on Pausing

* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-28 05:32:45 +00:00
rkuo-danswer
5be7d27285 use indexing flag in db for manually triggering indexing (#3264)
* use indexing flag in db for manually trigger indexing

* add comment.

* only try to release the lock if we actually succeeded with the lock

* ensure we don't trigger manual indexing on anything but the primary search settings

* comment usage of primary search settings

* run check for indexing immediately after indexing triggers are set

* reorder fix
2024-11-28 01:34:34 +00:00
Weves
fd84b7a768 Remove duplicate API key router 2024-11-27 16:30:59 -08:00
Subash-Mohan
36941ae663 fix: Cannot configure API keys #3191 2024-11-27 16:25:00 -08:00
Matthew Holland
212353ed4a Fixed default feedback options 2024-11-27 16:23:52 -08:00
Richard Kuo (Danswer)
eb8708f770 the word "error" might be throwing off sentry 2024-11-27 14:31:21 -08:00
Chris Weaver
ac448956e9 Add handling for rate limiting (#3280) 2024-11-27 14:22:15 -08:00
pablodanswer
634a0b9398 no stack by default (#3278) 2024-11-27 20:58:21 +00:00
hagen-danswer
09d3e47c03 Perm sync behavior change (#3262)
* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
2024-11-27 20:04:15 +00:00
pablodanswer
9c0cc94f15 refresh router -> refresh assistants (#3271) 2024-11-27 19:11:58 +00:00
hagen-danswer
07dfde2209 add continue in danswer button to slack bot responses (#3239)
* all done except routing

* fixed initial changes

* added backend endpoint for duplicating a chat session from Slack

* got chat duplication routing done

* got login routing working

* improved answer handling

* finished all checks

* finished all!

* made sure it works with google oauth

* dont remove that lol

* fixed weird thing

* bad comments
2024-11-27 18:25:38 +00:00
153 changed files with 5077 additions and 3866 deletions

View File

@@ -0,0 +1,35 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)
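
For context, a minimal sketch of how a consumer could read the flag this migration backfills, assuming a pared-down stand-in for the ChannelConfig TypedDict extended later in this diff (a missing key is treated as False):

# Hypothetical illustration only; the real ChannelConfig has more fields.
from typing import NotRequired, TypedDict


class ChannelConfig(TypedDict):
    follow_up_tags: NotRequired[list[str]]
    show_continue_in_web_ui: NotRequired[bool]  # defaults to False


def should_show_continue_button(channel_config: ChannelConfig | None) -> bool:
    # Mirrors the migration's default: a missing key means "do not show".
    return bool(channel_config and channel_config.get("show_continue_in_web_ui"))


print(should_show_continue_button({"follow_up_tags": []}))             # False
print(should_show_continue_button({"show_continue_in_web_ui": True}))  # True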

View File

@@ -0,0 +1,27 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -0,0 +1,30 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 93560ba1b118
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -5,7 +5,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -27,7 +28,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
result = app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
@@ -241,13 +242,17 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -24,6 +25,9 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -49,7 +53,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
@@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
result = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -166,8 +170,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
@@ -195,7 +204,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -228,9 +237,13 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
@@ -249,7 +262,6 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -260,6 +272,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
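
The change above replaces a bare boolean fence with a structured payload carrying the start time and owning Celery task id, cleared with None in the finally block. A generic, self-contained sketch of that lifecycle follows; the payload fields mirror the diff, while the pydantic base class and the in-memory Fence stand-in are assumptions for illustration only:

# Generic illustration of the boolean -> payload fence upgrade.
from datetime import datetime, timezone

from pydantic import BaseModel


class RedisConnectorExternalGroupSyncPayload(BaseModel):
    # field names taken from the diff; the base class is an assumption
    started: datetime | None
    celery_task_id: str | None


class Fence:
    """In-memory stand-in for the Redis-backed fence used by the connector."""

    def __init__(self) -> None:
        self.payload: RedisConnectorExternalGroupSyncPayload | None = None

    def set_fence(
        self, payload: RedisConnectorExternalGroupSyncPayload | None
    ) -> None:
        self.payload = payload


def run_sync(fence: Fence, task_id: str) -> None:
    fence.set_fence(
        RedisConnectorExternalGroupSyncPayload(
            started=datetime.now(timezone.utc), celery_task_id=task_id
        )
    )
    try:
        ...  # perform the external group sync
    finally:
        # always clear the fence so a failed task doesn't get stuck
        fence.set_fence(None)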

View File

@@ -25,11 +25,13 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
@@ -37,12 +39,13 @@ from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -159,7 +162,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -172,6 +175,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
@@ -205,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -231,22 +229,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
reindex,
db_session,
r,
tenant_id,
@@ -256,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
@@ -281,7 +303,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -289,13 +310,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -304,6 +326,7 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -368,6 +391,11 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -495,8 +523,11 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -509,6 +540,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
@@ -537,8 +572,30 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(10)
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
job.cancel()
break
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
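
The lock handling in check_for_indexing above also switches to a "remember whether we ever acquired, release only if still owned" pattern. A standalone sketch of that pattern using only redis-py (the lock name and task body are placeholders):

# Standalone illustration of the beat-lock pattern; assumes a reachable Redis.
from redis import Redis
from redis.lock import Lock as RedisLock


def run_beat_task(r: Redis) -> None:
    locked = False
    lock_beat: RedisLock = r.lock("example_beat_lock", timeout=60)
    try:
        if not lock_beat.acquire(blocking=False):
            return  # another worker holds the beat lock
        locked = True
        ...  # do the periodic work, reacquiring the lock as needed
    finally:
        if locked:
            if lock_beat.owned():
                lock_beat.release()
            else:
                # the lock timed out and was taken over mid-task
                print("beat lock not owned on completion")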

View File

@@ -46,6 +46,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -58,7 +59,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
@@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset(
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
@@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app: Celery = celery_app

View File

@@ -1,8 +1,10 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
app: Celery = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -87,6 +88,10 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
@@ -208,9 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -304,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -335,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
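
The restructuring above splits a deliberate stop from a genuine failure. A condensed, self-contained sketch of that control flow, with the indexing machinery reduced to callables (only ConnectorStopSignal comes from the diff; the rest is illustrative):

from collections.abc import Callable


class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing."""


def run_window(
    should_stop: Callable[[], bool],
    do_index_batch: Callable[[], None],
    on_canceled: Callable[[str], None],  # stands in for mark_attempt_canceled
    on_failed: Callable[[str], None],    # stands in for mark_attempt_failed
) -> None:
    try:
        if should_stop():
            raise ConnectorStopSignal("Connector stop signal detected")
        do_index_batch()
    except Exception as e:
        # a deliberate stop becomes "canceled"; anything else stays a failure
        if isinstance(e, ConnectorStopSignal):
            on_canceled(str(e))
        else:
            on_failed(str(e))
        raise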

View File

@@ -605,6 +605,7 @@ def stream_chat_message_objects(
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)

View File

@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible

View File

@@ -51,7 +51,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 1000
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -301,5 +301,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list
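
The connector change above swaps one large yield for fixed-size batches. A generic version of the same generator pattern, with arbitrary items standing in for Confluence slim documents:

from collections.abc import Iterable, Iterator
from typing import Any

_SLIM_DOC_BATCH_SIZE = 5000


def yield_in_batches(
    items: Iterable[Any], batch_size: int = _SLIM_DOC_BATCH_SIZE
) -> Iterator[list[Any]]:
    batch: list[Any] = []
    for item in items:
        batch.append(item)
        if len(batch) > batch_size:
            # same shape as the connector: emit a full batch, keep the overflow
            yield batch[:batch_size]
            batch = batch[batch_size:]
    yield batch  # final batch (may be short or empty, matching the connector)


# e.g. list(yield_in_batches(range(12), batch_size=5)) ->
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]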

View File

@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence(Confluence):

View File

@@ -18,20 +18,30 @@ from slack_sdk.models.blocks.block_elements import ImageElement
from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.context.search.models import SavedSearchDoc
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.icons import source_to_github_img_link
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.db.chat import get_chat_session_by_message_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
@@ -101,12 +111,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
return chunks
def clean_markdown_link_text(text: str) -> str:
def _clean_markdown_link_text(text: str) -> str:
# Remove any newlines within the text
return text.replace("\n", " ").strip()
def build_qa_feedback_block(
def _build_qa_feedback_block(
message_id: int, feedback_reminder_id: str | None = None
) -> Block:
return ActionsBlock(
@@ -115,7 +125,6 @@ def build_qa_feedback_block(
ButtonElement(
action_id=LIKE_BLOCK_ACTION_ID,
text="👍 Helpful",
style="primary",
value=feedback_reminder_id,
),
ButtonElement(
@@ -155,7 +164,7 @@ def get_document_feedback_blocks() -> Block:
)
def build_doc_feedback_block(
def _build_doc_feedback_block(
message_id: int,
document_id: str,
document_rank: int,
@@ -182,7 +191,7 @@ def get_restate_blocks(
]
def build_documents_blocks(
def _build_documents_blocks(
documents: list[SavedSearchDoc],
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
@@ -223,7 +232,7 @@ def build_documents_blocks(
feedback: ButtonElement | dict = {}
if message_id is not None:
feedback = build_doc_feedback_block(
feedback = _build_doc_feedback_block(
message_id=message_id,
document_id=d.document_id,
document_rank=rank,
@@ -241,7 +250,7 @@ def build_documents_blocks(
return section_blocks
def build_sources_blocks(
def _build_sources_blocks(
cited_documents: list[tuple[int, SavedSearchDoc]],
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
@@ -286,7 +295,7 @@ def build_sources_blocks(
+ ([days_ago_str] if days_ago_str else [])
)
document_title = clean_markdown_link_text(doc_sem_id)
document_title = _clean_markdown_link_text(doc_sem_id)
img_link = source_to_github_img_link(d.source_type)
section_blocks.append(
@@ -317,7 +326,50 @@ def build_sources_blocks(
return section_blocks
def build_quotes_block(
def _priority_ordered_documents_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
if not priority_ordered_docs:
return []
document_blocks = _build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
if document_blocks:
document_blocks = [DividerBlock()] + document_blocks
return document_blocks
def _build_citations_blocks(
answer: OneShotQAResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = _build_sources_blocks(cited_documents=cited_docs)
return citations_block
def _build_quotes_block(
quotes: list[DanswerQuote],
) -> list[Block]:
quote_lines: list[str] = []
@@ -359,58 +411,70 @@ def build_quotes_block(
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
quotes: list[DanswerQuote] | None,
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
def _build_qa_response_blocks(
answer: OneShotQAResponse,
skip_quotes: bool = False,
process_message_for_citations: bool = False,
skip_ai_feedback: bool = False,
feedback_reminder_id: str | None = None,
) -> list[Block]:
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
quotes = answer.quotes.quotes if answer.quotes else None
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
filter_block: Block | None = None
if time_cutoff or favor_recent or source_filters:
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
or retrieval_info.applied_source_filters
):
filter_text = "Filters: "
if source_filters:
sources_str = ", ".join([s.value for s in source_filters])
if retrieval_info.applied_source_filters:
sources_str = ", ".join(
[s.value for s in retrieval_info.applied_source_filters]
)
filter_text += f"`Sources in [{sources_str}]`"
if time_cutoff or favor_recent:
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
):
filter_text += " and "
if time_cutoff is not None:
time_str = time_cutoff.strftime("%b %d, %Y")
if retrieval_info.applied_time_cutoff is not None:
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
filter_text += f"`Docs Updated >= {time_str}` "
if favor_recent:
if time_cutoff is not None:
if retrieval_info.recency_bias_multiplier > 1:
if retrieval_info.applied_time_cutoff is not None:
filter_text += "+ "
filter_text += "`Prioritize Recently Updated Docs`"
filter_block = SectionBlock(text=f"_{filter_text}_")
if not answer:
if not formatted_answer:
answer_blocks = [
SectionBlock(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
]
else:
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
answer_processed = decode_escapes(
remove_slack_text_interactions(formatted_answer)
)
if process_message_for_citations:
answer_processed = _process_citations_for_slack(answer_processed)
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
if quotes:
quotes_blocks = build_quotes_block(quotes)
quotes_blocks = _build_quotes_block(quotes)
# if no quotes OR `build_quotes_block()` did not give back any blocks
# if no quotes OR `_build_quotes_block()` did not give back any blocks
if not quotes_blocks:
quotes_blocks = [
SectionBlock(
@@ -425,20 +489,37 @@ def build_qa_response_blocks(
response_blocks.extend(answer_blocks)
if message_id is not None and not skip_ai_feedback:
response_blocks.append(
build_qa_feedback_block(
message_id=message_id, feedback_reminder_id=feedback_reminder_id
)
)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
return response_blocks
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
)
return ActionsBlock(
block_id=build_continue_in_web_ui_id(message_id),
elements=[
ButtonElement(
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
text="Continue Chat in Danswer!",
style="primary",
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
),
],
)
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
return ActionsBlock(
block_id=build_feedback_id(message_id) if message_id is not None else None,
elements=[
@@ -483,3 +564,77 @@ def build_follow_up_resolved_blocks(
]
)
return [text_block, button_block]
def build_slack_response_blocks(
tenant_id: str | None,
message_info: SlackMessageInfo,
answer: OneShotQAResponse,
persona: Persona | None,
channel_conf: ChannelConfig | None,
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
) -> list[Block]:
"""
This function is a top level function that builds all the blocks for the Slack response.
It also handles combining all the blocks together.
"""
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
answer_blocks = _build_qa_response_blocks(
answer=answer,
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
)
web_follow_up_block = []
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
message_id=answer.chat_message_id,
feedback_reminder_id=feedback_reminder_id,
)
)
citations_blocks = []
document_blocks = []
if use_citations:
# if citations are enabled, only show cited documents
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
all_blocks = (
restate_question_block
+ answer_blocks
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block
)
return all_blocks

View File

@@ -2,6 +2,7 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"

View File

@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
@@ -267,7 +267,7 @@ def handle_followup_button(
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_user_ids_from_emails(
tag_ids, remaining = fetch_slack_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:

View File

@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
@@ -184,7 +184,7 @@ def handle_message(
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_user_ids_from_emails(
send_to, missing_ids = fetch_slack_user_ids_from_emails(
respond_member_group_list, client
)

View File

@@ -7,7 +7,6 @@ from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
@@ -25,12 +24,7 @@ from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
@@ -411,62 +405,16 @@ def handle_regular_answer(
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=formatted_answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
persona=persona,
channel_conf=channel_conf,
use_citations=use_citations,
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,

View File

@@ -3,9 +3,9 @@ import random
import re
import string
import time
import uuid
from typing import Any
from typing import cast
from typing import Optional
from retry import retry
from slack_sdk import WebClient
@@ -216,6 +216,13 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
unique_prefix = str(uuid.uuid4())[:10]
return unique_prefix + ID_SEPARATOR + str(message_id)
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
"""Decompose into query_id, document_id, document_rank, see above function"""
try:
@@ -313,7 +320,7 @@ def get_channel_name_from_id(
raise e
def fetch_user_ids_from_emails(
def fetch_slack_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -522,7 +529,7 @@ class SlackRateLimiter:
self.last_reset_time = time.time()
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
self, client: WebClient, channel: str, position: int, thread_ts: str | None
) -> None:
respond_in_thread(
client=client,

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -30,6 +31,7 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
@@ -250,6 +252,50 @@ def create_chat_session(
return chat_session
def duplicate_chat_session_for_user_from_slack(
db_session: Session,
user: User | None,
chat_session_id: UUID,
) -> ChatSession:
"""
This takes a chat session id for a session in Slack and:
- Creates a new chat session in the DB
- Tries to copy the persona from the original chat session
(if it is available to the user clicking the button)
- Sets the user to the given user (if provided)
"""
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=None, # Ignore user permissions for this
db_session=db_session,
)
if not chat_session:
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
# This enforces permissions and sets a default
new_persona_id = get_best_persona_id_for_user(
db_session=db_session,
user=user,
persona_id=chat_session.persona_id,
)
return create_chat_session(
db_session=db_session,
user_id=user.id if user else None,
persona_id=new_persona_id,
# Set this to empty string so the frontend will force a rename
description="",
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat sessions from Slack should put people in the chat UI, not the search
one_shot=False,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack
slack_thread_id=None,
)
def update_chat_session(
db_session: Session,
user_id: UUID | None,
@@ -336,6 +382,28 @@ def get_chat_message(
return chat_message
def get_chat_session_by_message_id(
db_session: Session,
message_id: int,
) -> ChatSession:
"""
Should only be used for Slack
Get the chat session associated with a specific message ID
Note: this ignores permission checks.
"""
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
result = db_session.execute(stmt)
chat_message = result.scalar_one_or_none()
if chat_message is None:
raise ValueError(
f"Unable to find chat session associated with message ID: {message_id}"
)
return chat_message.chat_session
def get_chat_messages_by_sessions(
chat_session_ids: list[UUID],
user_id: UUID | None,
@@ -355,6 +423,44 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()
def add_chats_to_session_from_slack_thread(
db_session: Session,
slack_chat_session_id: UUID,
new_chat_session_id: UUID,
) -> None:
new_root_message = get_or_create_root_message(
chat_session_id=new_chat_session_id,
db_session=db_session,
)
for chat_message in get_chat_messages_by_sessions(
chat_session_ids=[slack_chat_session_id],
user_id=None, # Ignore user permissions for this
db_session=db_session,
skip_permission_check=True,
):
if chat_message.message_type == MessageType.SYSTEM:
continue
# Duplicate the message
new_root_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_root_message,
message=chat_message.message,
files=chat_message.files,
rephrased_query=chat_message.rephrased_query,
error=chat_message.error,
citations=chat_message.citations,
reference_docs=chat_message.search_docs,
tool_call=chat_message.tool_call,
prompt_id=chat_message.prompt_id,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
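
A rough sketch of how the two new helpers above could be wired together when a Slack thread is continued in the web UI, based only on the signatures shown in this diff (the endpoint glue itself is not part of this file):

# Hypothetical glue code: duplicate a Slack-originated chat session for the
# requesting user, then copy the thread's messages into it.
from uuid import UUID

from sqlalchemy.orm import Session

from danswer.db.chat import (
    add_chats_to_session_from_slack_thread,
    duplicate_chat_session_for_user_from_slack,
)
from danswer.db.models import User


def continue_slack_chat_in_web_ui(
    db_session: Session, user: User | None, slack_chat_session_id: UUID
) -> UUID:
    new_session = duplicate_chat_session_for_user_from_slack(
        db_session=db_session,
        user=user,
        chat_session_id=slack_chat_session_id,
    )
    add_chats_to_session_from_slack_thread(
        db_session=db_session,
        slack_chat_session_id=slack_chat_session_id,
        new_chat_session_id=new_session.id,
    )
    return new_session.id  # id of the duplicated session, now owned by the user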

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingMode
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -311,3 +312,25 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()
def mark_ccpair_with_indexing_trigger(
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
"""indexing_mode sets a field which will be picked up by a background task
to trigger indexing. Set to None to disable the trigger."""
try:
cc_pair = db_session.execute(
select(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.with_for_update()
).scalar_one()
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.indexing_trigger = indexing_mode
db_session.commit()
except Exception:
db_session.rollback()
raise
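
As a usage sketch, a caller behind the manual run controls might set the trigger like this; check_for_indexing later detects it, logs it, and clears it. The imports match modules touched elsewhere in this diff, but the wrapper function itself is hypothetical:

from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingMode


def trigger_manual_run(cc_pair_id: int, tenant_id: str | None, reindex: bool) -> None:
    # REINDEX asks for a from-scratch run; UPDATE asks for a normal incremental run.
    mode = IndexingMode.REINDEX if reindex else IndexingMode.UPDATE
    with get_session_with_tenant(tenant_id) as db_session:
        mark_ccpair_with_indexing_trigger(cc_pair_id, mode, db_session)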

View File

@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
user_group_ids: list[int],
user_group_ids: list[int] | None = None,
) -> None:
if not user_group_ids:
return
for group_id in user_group_ids:
db_session.add(
UserGroup__ConnectorCredentialPair(
@@ -402,12 +405,11 @@ def add_credential_to_connector(
db_session.flush() # make sure the association has an id
db_session.refresh(association)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
db_session.commit()

View File

@@ -19,6 +19,11 @@ class IndexingStatus(str, PyEnum):
return self in terminal_states
class IndexingMode(str, PyEnum):
UPDATE = "update"
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"

View File

@@ -42,7 +42,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType
from danswer.db.enums import AccessType, IndexingMode
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -126,6 +126,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -438,6 +439,10 @@ class ConnectorCredentialPair(Base):
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
Enum(IndexingMode, native_enum=False), nullable=True
)
connector: Mapped["Connector"] = relationship(
"Connector", back_populates="credentials"
)
@@ -1480,6 +1485,7 @@ class ChannelConfig(TypedDict):
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
show_continue_in_web_ui: NotRequired[bool] # defaults to False
class SlackBotResponseType(str, PyEnum):

View File

@@ -113,6 +113,31 @@ def fetch_persona_by_id(
return persona
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
if persona_id is not None:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(
stmt=stmt,
user=user,
# We don't want to filter by editable here, we just want to see if the
# persona is usable by the user
get_editable=False,
)
persona = db_session.scalars(stmt).one_or_none()
if persona:
return persona.id
# If the persona is not found, or the slack bot is using doc sets instead of personas,
# we need to find the best persona for the user
# This is the persona with the highest display priority that the user has access to
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
persona = db_session.scalars(stmt).one_or_none()
return persona.id if persona else None
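A hedged usage sketch of the fallback behavior described in the comments above; the module path and the calling convention are assumptions for illustration.
from sqlalchemy.orm import Session

from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user  # module path assumed


def resolve_persona_id(
    db_session: Session, user: User | None, pinned_persona_id: int | None
) -> int:
    # Falls back to the highest-display-priority persona the user can access
    # when the pinned persona is missing or not usable by this user.
    persona_id = get_best_persona_id_for_user(db_session, user, pinned_persona_id)
    if persona_id is None:
        raise ValueError("No persona is accessible to this user")
    return persona_id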
def _get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
@@ -160,7 +185,7 @@ def create_update_persona(
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.dict(exclude={"users", "groups"}),
**create_persona_request.model_dump(exclude={"users", "groups"}),
}
persona = upsert_persona(**persona_data)
@@ -733,6 +758,8 @@ def get_prompt_by_name(
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result

View File

@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list
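A small usage sketch that relies only on the ordering guarantee documented above (index 0 is the current settings; index 1, if present, is the settings being migrated to):
from sqlalchemy.orm import Session

from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings


def split_active_settings(
    db_session: Session,
) -> tuple[SearchSettings, SearchSettings | None]:
    active = get_active_search_settings(db_session)
    return active[0], (active[1] if len(active) > 1 else None)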
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)

View File

@@ -295,7 +295,7 @@ def pptx_to_text(file: IO[Any]) -> str:
def xlsx_to_text(file: IO[Any]) -> str:
workbook = openpyxl.load_workbook(file)
workbook = openpyxl.load_workbook(file, read_only=True)
text_content = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(

View File

@@ -26,7 +26,9 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
# NOTE: if tool calls are present, then it's an assistant.
# In Ollama, the role will be None for tool-calls
elif role == "assistant" or tool_calls:
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -236,6 +240,7 @@ class DefaultMultiLLM(LLM):
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
model_kwargs: dict[str, Any] | None = None,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
@@ -268,7 +273,7 @@ class DefaultMultiLLM(LLM):
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs: dict[str, Any] = {}
model_kwargs = model_kwargs or {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:

View File

@@ -1,5 +1,8 @@
from typing import Any
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@ from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.long_term_log import LongTermLogger
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
TODO: allow model-specific values to be configured via the UI.
"""
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
model_kwargs=_build_extra_model_kwargs(provider),
long_term_logger=long_term_logger,
)

View File

@@ -1,3 +1,4 @@
import copy
import io
import json
from collections.abc import Callable
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
return error_msg
def get_model_map() -> dict:
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
# NOTE: we could add additional models here in the future,
# but for now there is no point. Ollama allows the user to
# specify their desired max context window, and it's

# unlikely to be standard across users even for the same model
# (it heavily depends on their hardware). For now, we'll just
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
# for model_name in [
# "llama3.2",
# "llama3.2:1b",
# "llama3.2:3b",
# "llama3.2:11b",
# "llama3.2:90b",
# ]:
# starting_map[f"ollama/{model_name}"] = {
# "max_tokens": 128000,
# "max_input_tokens": 128000,
# "max_output_tokens": 128000,
# }
return starting_map
def _strip_extra_provider_from_model_name(model_name: str) -> str:
return model_name.split("/")[1] if "/" in model_name else model_name
def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]
# First try all model names with provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj
# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj
return None
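A hedged worked example of the lookup order, run against the helpers defined above in the same module; the model_map contents and the proxied model name are made up for illustration:
model_map = {"ollama/llama3.2": {"max_input_tokens": 128000}}

model_name = "litellm_proxy/llama3.2:3b"
candidates = [
    model_name,                                          # litellm_proxy/llama3.2:3b
    _strip_extra_provider_from_model_name(model_name),   # llama3.2:3b
    _strip_colon_from_model_name(model_name),            # litellm_proxy/llama3.2
    _strip_colon_from_model_name(
        _strip_extra_provider_from_model_name(model_name)
    ),                                                   # llama3.2
]
# The provider-prefixed pass matches "ollama/llama3.2" on the last candidate.
assert _find_model_obj(model_map, "ollama", candidates) == {"max_input_tokens": 128000}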
def get_llm_max_tokens(
model_map: dict,
model_name: str,
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS
try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {model_provider}/{model_name}")
if not model_obj:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
if not model_obj:
model_name_split = model_name.split("/")
if len(model_name_split) > 1:
model_obj = model_map.get(model_name_split[1])
if model_obj:
logger.debug(f"Using model object for {model_name_split[1]}")
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where the user has a
# custom model proxy which prepends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
@@ -488,7 +545,7 @@ def get_max_input_tokens(
# `model_cost` dict is a named public interface:
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
# model_map is litellm.model_cost
litellm_model_map = litellm.model_cost
litellm_model_map = get_model_map()
input_toks = (
get_llm_max_tokens(

View File

@@ -26,6 +26,7 @@ from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import create_danswer_oauth_router
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -44,6 +45,7 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.db.engine import warm_up_connections
from danswer.server.api_key.api import router as api_key_router
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -280,6 +282,7 @@ def get_application() -> FastAPI:
application, get_full_openai_assistants_api_router()
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -323,7 +326,7 @@ def get_application() -> FastAPI:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_oauth_router(
create_danswer_oauth_router(
oauth_client,
auth_backend,
USER_AUTH_SECRET,

View File

@@ -0,0 +1,4 @@
class ModelServerRateLimitError(Exception):
"""
Exception raised for rate limiting errors from the model server.
"""

View File

@@ -6,6 +6,9 @@ from typing import Any
import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry
from danswer.configs.app_configs import LARGE_CHUNK_RATIO
@@ -16,6 +19,9 @@ from danswer.configs.model_configs import (
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.exceptions import (
ModelServerRateLimitError,
)
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -99,28 +105,43 @@ class EmbeddingModel:
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
def _make_request() -> Response:
response = requests.post(
self.embed_server_endpoint, json=embed_request.model_dump()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# signify that this is a rate limit error
if response.status_code == 429:
raise ModelServerRateLimitError(response.text)
return EmbedResponse(**response.json())
response.raise_for_status()
return response
# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
final_make_request_func = _make_request
# if the text type is a passage, add some default
# retries + handling for rate limiting
if embed_request.text_type == EmbedTextType.PASSAGE:
return retry(tries=3, delay=5)(_make_request)()
else:
return _make_request()
final_make_request_func = retry(
tries=3,
delay=5,
exceptions=(RequestException, ValueError, JSONDecodeError),
)(final_make_request_func)
# use 10 second delay as per Azure suggestion
final_make_request_func = retry(
tries=10, delay=10, exceptions=ModelServerRateLimitError
)(final_make_request_func)
try:
response = final_make_request_func()
return EmbedResponse(**response.json())
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
def _batch_encode_texts(
self,

View File

@@ -131,7 +131,7 @@ def _try_initialize_tokenizer(
return tokenizer
except Exception as hf_error:
logger.warning(
f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
)
# If both initializations fail, return None

View File

@@ -47,6 +47,7 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.one_shot_answer.qa_utils import slackify_message_thread
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -194,13 +195,22 @@ def stream_answer_objects(
)
prompt = persona.prompts[0]
user_message_str = query_msg.message
# For this endpoint, we only save one user message to the chat session
# However, for slackbot, we want to include the history of the entire thread
if danswerbot_flow:
# Right now, we only support bringing over citations and search docs
# from the last message in the thread, not the entire thread
# in the future, we may want to retrieve the entire thread
user_message_str = slackify_message_thread(query_req.messages)
# Create the first User query message
new_user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=query_msg.message,
token_count=len(llm_tokenizer.encode(query_msg.message)),
message=user_message_str,
token_count=len(llm_tokenizer.encode(user_message_str)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,

View File

@@ -51,3 +51,31 @@ def combine_message_thread(
total_token_count += message_token_count
return "\n\n".join(message_strs)
def slackify_message(message: ThreadMessage) -> str:
if message.role != MessageType.USER:
return message.message
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
if not messages:
return ""
message_strs: list[str] = []
for message in messages:
if message.role == MessageType.USER:
message_text = (
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
)
elif message.role == MessageType.ASSISTANT:
message_text = f"DanswerBot said in Slack:\n{message.message}"
else:
message_text = (
f"{message.role.value.upper()} said in Slack:\n{message.message}"
)
message_strs.append(message_text)
return "\n\n".join(message_strs)

View File

@@ -1,5 +1,8 @@
import time
import redis
from danswer.db.models import SearchSettings
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -31,6 +34,44 @@ class RedisConnector:
self.tenant_id, self.id, search_settings_id, self.redis
)
def wait_for_indexing_termination(
self,
search_settings_list: list[SearchSettings],
timeout: float = 15.0,
) -> bool:
"""
Returns True if all indexing for the given redis connector is finished within the given timeout.
Returns False if the timeout is exceeded.
This check does not guarantee that indexing runs currently being terminated
won't get restarted mid-flight
"""
finished = False
start = time.monotonic()
while True:
still_indexing = False
for search_settings in search_settings_list:
redis_connector_index = self.new_index(search_settings.id)
if redis_connector_index.fenced:
still_indexing = True
break
if not still_indexing:
finished = True
break
now = time.monotonic()
if now - start > timeout:
break
time.sleep(1)
continue
return finished
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""

View File

@@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
class RedisConnectorPermissionSyncData(BaseModel):
class RedisConnectorPermissionSyncPayload(BaseModel):
started: datetime | None
celery_task_id: str | None
class RedisConnectorPermissionSync:
@@ -78,14 +79,14 @@ class RedisConnectorPermissionSync:
return False
@property
def payload(self) -> RedisConnectorPermissionSyncData | None:
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPermissionSyncData.model_validate_json(
payload = RedisConnectorPermissionSyncPayload.model_validate_json(
cast(str, fence_str)
)
@@ -93,7 +94,7 @@ class RedisConnectorPermissionSync:
def set_fence(
self,
payload: RedisConnectorPermissionSyncData | None,
payload: RedisConnectorPermissionSyncPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
@@ -162,6 +163,12 @@ class RedisConnectorPermissionSync:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"

View File

@@ -1,11 +1,18 @@
from datetime import datetime
from typing import cast
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
class RedisConnectorExternalGroupSyncPayload(BaseModel):
started: datetime | None
celery_task_id: str | None
class RedisConnectorExternalGroupSync:
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
through RedisConnector."""
@@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync:
return False
def set_fence(self, value: bool) -> None:
if not value:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(
self,
payload: RedisConnectorExternalGroupSyncPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
self.redis.set(self.fence_key, 0)
self.redis.set(self.fence_key, payload.model_dump_json())
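A short round-trip sketch of the new payload-based fence, assuming group_sync is an already-constructed instance:
from datetime import datetime, timezone

from danswer.redis.redis_connector_ext_group_sync import (
    RedisConnectorExternalGroupSync,
    RedisConnectorExternalGroupSyncPayload,
)


def mark_group_sync_started(
    group_sync: RedisConnectorExternalGroupSync, celery_task_id: str
) -> None:
    payload = RedisConnectorExternalGroupSyncPayload(
        started=datetime.now(timezone.utc),
        celery_task_id=celery_task_id,
    )
    group_sync.set_fence(payload)  # stored as model_dump_json at the fence key


def clear_group_sync_fence(group_sync: RedisConnectorExternalGroupSync) -> None:
    group_sync.set_fence(None)  # deletes the fence key entirely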
@property
def generator_complete(self) -> int | None:

View File

@@ -29,6 +29,8 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
def __init__(
self,
tenant_id: str | None,
@@ -51,6 +53,7 @@ class RedisConnectorIndex:
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -92,6 +95,18 @@ class RedisConnectorIndex:
self.redis.set(self.fence_key, payload.model_dump_json())
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Knowledge"
name: "Search"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
@@ -158,7 +161,19 @@ def update_cc_pair_status(
status_update_request: CCStatusUpdateRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""This method may wait up to 30 seconds if pausing the connector due to the need to
terminate tasks in progress. Tasks are not guaranteed to terminate within the
timeout.
Returns HTTPStatus.OK if everything finished.
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
did not finish within the timeout.
"""
WAIT_TIMEOUT = 15.0
still_terminating = False
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -173,10 +188,76 @@ def update_cc_pair_status(
)
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
redis_connector.stop.set_fence(True)
while True:
logger.debug(
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
"Wait for indexing soft termination timed out. "
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
if not redis_connector_index.fenced:
continue
index_payload = redis_connector_index.payload
if not index_payload:
continue
if not index_payload.celery_task_id:
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
logger.debug(
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
)
still_terminating = True
break
finally:
redis_connector.stop.set_fence(False)
update_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -185,6 +266,18 @@ def update_cc_pair_status(
db_session.commit()
if still_terminating:
return JSONResponse(
status_code=HTTPStatus.ACCEPTED,
content={
"message": "Request accepted, background task termination still in progress"
},
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)
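A hedged client-side sketch of how a caller might handle the two success codes documented above; the URL shape and request payload are assumptions, not taken from this diff:
from http import HTTPStatus

import requests

STATUS_URL = "http://localhost:8080/api/manage/admin/cc-pair/1/status"  # assumed URL shape

resp = requests.put(STATUS_URL, json={"status": "PAUSED"})  # payload shape assumed
if resp.status_code == HTTPStatus.ACCEPTED:
    # 202: pause accepted, but background indexing tasks are still terminating.
    pass
elif resp.status_code == HTTPStatus.OK:
    # 200: all indexing work stopped within the timeout.
    pass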
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@@ -267,9 +360,9 @@ def prune_cc_pair(
)
logger.info(
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"Pruning cc_pair: cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(

View File

@@ -17,9 +17,9 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.connectors.google_utils.google_auth import (
@@ -59,6 +59,7 @@ from danswer.db.connector import delete_connector
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import fetch_connectors
from danswer.db.connector import get_connector_credential_ids
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector import update_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
@@ -74,6 +75,7 @@ from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import IndexingMode
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_latest_index_attempts
@@ -86,7 +88,6 @@ from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -792,12 +793,10 @@ def connector_run_once(
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
) -> StatusResponse[int]:
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids
@@ -843,54 +842,41 @@ def connector_run_once(
)
]
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(connector_id, credential_id, db_session)
for credential_id in credential_ids
if credential_id not in skipped_credentials
]
index_attempt_ids = []
num_triggers = 0
for cc_pair in connector_credential_pairs:
if cc_pair is not None:
attempt_id = try_creating_indexing_task(
primary_app,
cc_pair,
search_settings,
run_info.from_beginning,
db_session,
r,
tenant_id,
indexing_mode = IndexingMode.UPDATE
if run_info.from_beginning:
indexing_mode = IndexingMode.REINDEX
mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
num_triggers += 1
logger.info(
f"connector_run_once - marking cc_pair with indexing trigger: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id} "
f"indexing_trigger={indexing_mode}"
)
if attempt_id:
logger.info(
f"connector_run_once - try_creating_indexing_task succeeded: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id} "
f"attempt={attempt_id} "
)
index_attempt_ids.append(attempt_id)
else:
logger.info(
f"connector_run_once - try_creating_indexing_task failed: "
f"connector={run_info.connector_id} "
f"cc_pair={cc_pair.id}"
)
if not index_attempt_ids:
msg = "No new indexing attempts created, indexing jobs are queued or running."
logger.info(msg)
raise HTTPException(
status_code=400,
detail=msg,
)
# run the beat task to pick up the triggers immediately
primary_app.send_task(
"check_for_indexing",
priority=DanswerCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
)
msg = f"Successfully created {len(index_attempt_ids)} index attempts. {index_attempt_ids}"
msg = f"Marked {num_triggers} index attempts with indexing triggers."
return StatusResponse(
success=True,
message=msg,
data=index_attempt_ids,
data=num_triggers,
)

View File

@@ -45,6 +45,7 @@ class UserPreferences(BaseModel):
visible_assistants: list[int] = []
recent_assistants: list[int] | None = None
default_model: str | None = None
auto_scroll: bool | None = None
class UserInfo(BaseModel):
@@ -79,6 +80,7 @@ class UserInfo(BaseModel):
role=user.role,
preferences=(
UserPreferences(
auto_scroll=user.auto_scroll,
chosen_assistants=user.chosen_assistants,
default_model=user.default_model,
hidden_assistants=user.hidden_assistants,
@@ -128,6 +130,10 @@ class HiddenUpdateRequest(BaseModel):
hidden: bool
class AutoScrollRequest(BaseModel):
auto_scroll: bool | None
class SlackBotCreationRequest(BaseModel):
name: str
enabled: bool
@@ -156,6 +162,7 @@ class SlackChannelConfigCreationRequest(BaseModel):
channel_name: str
respond_tag_only: bool = False
respond_to_bots: bool = False
show_continue_in_web_ui: bool = False
enable_auto_filters: bool = False
# If no team members, assume respond in the channel to everyone
respond_member_group_list: list[str] = Field(default_factory=list)

View File

@@ -80,6 +80,10 @@ def _form_channel_config(
if follow_up_tags is not None:
channel_config["follow_up_tags"] = follow_up_tags
channel_config[
"show_continue_in_web_ui"
] = slack_channel_config_creation_request.show_continue_in_web_ui
channel_config[
"respond_to_bots"
] = slack_channel_config_creation_request.respond_to_bots

View File

@@ -52,6 +52,7 @@ from danswer.db.users import list_users
from danswer.db.users import validate_user_role_update
from danswer.key_value_store.factory import get_kv_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import AutoScrollRequest
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
@@ -497,7 +498,6 @@ def verify_user_logged_in(
return fetch_no_auth_user(store)
raise BasicAuthenticationError(detail="User Not Authenticated")
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
@@ -581,6 +581,30 @@ def update_user_recent_assistants(
db_session.commit()
@router.patch("/auto-scroll")
def update_user_auto_scroll(
request: AutoScrollRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.auto_scroll = request.auto_scroll
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:
raise RuntimeError("This should never happen")
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(auto_scroll=request.auto_scroll)
)
db_session.commit()
@router.patch("/user/default-model")
def update_user_default_model(
request: ChosenDefaultModelRequest,

View File

@@ -27,9 +27,11 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.db.chat import add_chats_to_session_from_slack_thread
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
from danswer.db.chat import duplicate_chat_session_for_user_from_slack
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
@@ -532,6 +534,38 @@ def seed_chat(
)
class SeedChatFromSlackRequest(BaseModel):
chat_session_id: UUID
class SeedChatFromSlackResponse(BaseModel):
redirect_url: str
@router.post("/seed-chat-session-from-slack")
def seed_chat_from_slack(
chat_seed_request: SeedChatFromSlackRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SeedChatFromSlackResponse:
slack_chat_session_id = chat_seed_request.chat_session_id
new_chat_session = duplicate_chat_session_for_user_from_slack(
db_session=db_session,
user=user,
chat_session_id=slack_chat_session_id,
)
add_chats_to_session_from_slack_thread(
db_session=db_session,
slack_chat_session_id=slack_chat_session_id,
new_chat_session_id=new_chat_session.id,
)
return SeedChatFromSlackResponse(
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
)
"""File upload"""

View File

@@ -79,6 +79,7 @@ class CreateChatMessageRequest(ChunkContext):
message: str
# Files that we should attach to this message
file_descriptors: list[FileDescriptor]
# If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred
# Use prompt_id 0 to use the system default prompt which is Answer-Question

View File

@@ -2,7 +2,6 @@ from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
@@ -38,10 +37,6 @@ basic_router = APIRouter(prefix="/settings")
def put_settings(
settings: Settings, _: User | None = Depends(current_admin_user)
) -> None:
try:
settings.check_validity()
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
store_settings(settings)

View File

@@ -41,33 +41,10 @@ class Notification(BaseModel):
class Settings(BaseModel):
"""General settings"""
chat_page_enabled: bool = True
search_page_enabled: bool = True
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE
def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled
search_page_enabled = self.search_page_enabled
default_page = self.default_page
if chat_page_enabled is False and search_page_enabled is False:
raise ValueError(
"One of `search_page_enabled` and `chat_page_enabled` must be True."
)
if default_page == PageType.CHAT and chat_page_enabled is False:
raise ValueError(
"The default page cannot be 'chat' if the chat page is disabled."
)
if default_page == PageType.SEARCH and search_page_enabled is False:
raise ValueError(
"The default page cannot be 'search' if the search page is disabled."
)
class UserSettings(Settings):
notifications: list[Notification]

View File

@@ -11,6 +11,7 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential__UserGroup
@@ -298,6 +299,11 @@ def fetch_user_groups_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[str]]]:
"""
Fetches all user groups that have access to the given documents.
NOTE: this doesn't include groups if the cc_pair is access type SYNC
"""
stmt = (
select(Document.id, func.array_agg(UserGroup.name))
.join(
@@ -306,7 +312,11 @@ def fetch_user_groups_for_documents(
)
.join(
ConnectorCredentialPair,
ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id,
and_(
ConnectorCredentialPair.id
== UserGroup__ConnectorCredentialPair.cc_pair_id,
ConnectorCredentialPair.access_type != AccessType.SYNC,
),
)
.join(
DocumentByConnectorCredentialPair,

View File

@@ -16,7 +16,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
_VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE"
_REQUEST_PAGINATION_LIMIT = 100
_REQUEST_PAGINATION_LIMIT = 5000
def _get_server_space_permissions(
@@ -97,6 +97,7 @@ def _get_space_permissions(
confluence_client: OnyxConfluence,
is_cloud: bool,
) -> dict[str, ExternalAccess]:
logger.debug("Getting space permissions")
# Gets all the spaces in the Confluence instance
all_space_keys = []
start = 0
@@ -113,6 +114,7 @@ def _get_space_permissions(
start += len(spaces_batch.get("results", []))
# Gets the permissions for each space
logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
space_permissions_by_space_key: dict[str, ExternalAccess] = {}
for space_key in all_space_keys:
if is_cloud:
@@ -242,6 +244,7 @@ def _fetch_all_page_restrictions_for_space(
logger.warning(f"No permissions found for document {slim_doc.id}")
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
@@ -254,27 +257,28 @@ def confluence_doc_sync(
it in postgres so that when it gets created later, the permissions are
already populated
"""
logger.debug("Starting confluence doc sync")
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
if confluence_connector.confluence_client is None:
raise ValueError("Failed to load credentials")
confluence_client = confluence_connector.confluence_client
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
space_permissions_by_space_key = _get_space_permissions(
confluence_client=confluence_client,
confluence_client=confluence_connector.confluence_client,
is_cloud=is_cloud,
)
slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions_for_space(
confluence_client=confluence_client,
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
)

View File

@@ -14,7 +14,10 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
user = user_result["user"]
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server

View File

@@ -57,9 +57,9 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 60 seconds
DocumentSource.GOOGLE_DRIVE: 60,
DocumentSource.CONFLUENCE: 60,
# Polling is not supported so we fetch all group permissions every 5 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: 5 * 60,
}

View File

@@ -13,7 +13,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.main import get_application as get_application_base
from danswer.main import include_router_with_global_prefix_prepended
from danswer.server.api_key.api import router as api_key_router
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
@@ -116,8 +115,6 @@ def get_application() -> FastAPI:
# Analytics endpoints
include_router_with_global_prefix_prepended(application, analytics_router)
include_router_with_global_prefix_prepended(application, query_history_router)
# Api key management
include_router_with_global_prefix_prepended(application, api_key_router)
# EE only backend APIs
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)

View File

@@ -113,10 +113,6 @@ async def refresh_access_token(
def put_settings(
settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
) -> None:
try:
settings.check_validity()
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
store_settings(settings)

View File

@@ -157,7 +157,6 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
def _seed_settings(settings: Settings) -> None:
logger.notice("Seeding Settings")
try:
settings.check_validity()
store_base_settings(settings)
logger.notice("Successfully seeded Settings")
except ValueError as e:

View File

@@ -11,6 +11,7 @@ from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
@@ -205,28 +206,22 @@ class CloudEmbedding:
model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Error embedding text with {self.provider}: {str(e)}",
)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@staticmethod
def create(
@@ -430,6 +425,11 @@ async def process_embed_request(
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:
raise HTTPException(
status_code=429,
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)

View File

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.50.2
litellm==1.53.1
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.52.2
openai==1.55.3
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5

View File

@@ -7,6 +7,7 @@ from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbeddingProvider
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
VALID_LONG_SAMPLE = ["hi " * 999]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
# seem to be true
TOO_LONG_SAMPLE = ["a"] * 2500
@@ -99,3 +100,42 @@ def local_nomic_embedding_model() -> EmbeddingModel:
def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
_run_embeddings(VALID_SAMPLE, local_nomic_embedding_model, 768)
_run_embeddings(TOO_LONG_SAMPLE, local_nomic_embedding_model, 768)
@pytest.fixture
def azure_embedding_model() -> EmbeddingModel:
return EmbeddingModel(
server_host="localhost",
server_port=9000,
model_name="text-embedding-3-large",
normalize=True,
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("AZURE_API_KEY"),
provider_type=EmbeddingProvider.AZURE,
api_url=os.getenv("AZURE_API_URL"),
)
# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
# """NOTE: this test relies on a very low rate limit for the Azure API +
# this test only being run once in a 1 minute window"""
# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
# # limits assuming the limit is 1000 tokens per minute
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# assert len(result) == 1
# assert len(result[0]) == 1536
# # this should fail
# with pytest.raises(ModelServerRateLimitError):
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
# # this should succeed, since passage requests retry up to 10 times
# start = time.time()
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
# assert len(result) == 1
# assert len(result[0]) == 1536
# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits

View File

@@ -240,7 +240,85 @@ class CCPairManager:
result.raise_for_status()
@staticmethod
def wait_for_indexing(
def wait_for_indexing_inactive(
cc_pair: DATestCCPair,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
terminating that indexing."""
print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}")
start = time.monotonic()
while True:
fetched_cc_pairs = CCPairManager.get_indexing_statuses(
user_performing_action
)
for fetched_cc_pair in fetched_cc_pairs:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
if fetched_cc_pair.in_progress:
continue
print(f"Indexing is inactive: cc_pair={cc_pair.id}")
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s"
)
print(
f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
)
time.sleep(5)
@staticmethod
def wait_for_indexing_in_progress(
cc_pair: DATestCCPair,
timeout: float = MAX_DELAY,
num_docs: int = 16,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
terminating that indexing."""
start = time.monotonic()
while True:
fetched_cc_pairs = CCPairManager.get_indexing_statuses(
user_performing_action
)
for fetched_cc_pair in fetched_cc_pairs:
if fetched_cc_pair.cc_pair_id != cc_pair.id:
continue
if not fetched_cc_pair.in_progress:
continue
if fetched_cc_pair.docs_indexed >= num_docs:
print(
"Indexed at least the requested number of docs: "
f"cc_pair={cc_pair.id} "
f"docs_indexed={fetched_cc_pair.docs_indexed} "
f"num_docs={num_docs}"
)
return
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
)
print(
f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
)
time.sleep(5)
@staticmethod
def wait_for_indexing_completion(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,

View File

@@ -14,6 +14,7 @@ from tests.integration.common_utils.managers.document_search import (
)
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
@@ -77,7 +78,7 @@ def test_slack_permission_sync(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -112,7 +113,7 @@ def test_slack_permission_sync(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -215,3 +216,124 @@ def test_slack_permission_sync(
# Ensure test_user_1 can only see messages from the public channel
assert public_message in danswer_doc_message_strings
assert private_message not in danswer_doc_message_strings
def test_slack_group_permission_sync(
reset: None,
vespa_client: vespa_fixture,
slack_test_setup: tuple[dict[str, Any], dict[str, Any]],
) -> None:
"""
This test ensures that permission sync overrides danswer group access.
"""
public_channel, private_channel = slack_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@onyx-test.com",
)
# Create a user group and adding the non-admin user to it
user_group = UserGroupManager.create(
name="test_group",
user_ids=[test_user_1.id],
cc_pair_ids=[],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group],
user_performing_action=admin_user,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
LLMProviderManager.create(user_performing_action=admin_user)
# Add only admin to the private channel
SlackManager.set_channel_members(
slack_client=slack_client,
admin_user_id=admin_user_id,
channel=private_channel,
user_ids=[admin_user_id],
)
before = datetime.now(timezone.utc)
credential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
},
user_performing_action=admin_user,
)
# Create connector with sync access and assign it to the user group
connector = ConnectorManager.create(
name="Slack",
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [private_channel["name"]],
},
access_type=AccessType.SYNC,
groups=[user_group.id],
user_performing_action=admin_user,
)
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC,
user_performing_action=admin_user,
groups=[user_group.id],
)
# Add a test message to the private channel
private_message = "This is a secret message: 987654"
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
message=private_message,
)
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
)
# Run permission sync
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
)
# Verify admin can see the message
admin_docs = DocumentSearchManager.search_documents(
query="secret message",
user_performing_action=admin_user,
)
assert private_message in admin_docs
# Verify test_user_1 cannot see the message despite being in the group
# (Slack permissions should take precedence)
user_1_docs = DocumentSearchManager.search_documents(
query="secret message",
user_performing_action=test_user_1,
)
assert private_message not in user_1_docs

View File

@@ -74,7 +74,7 @@ def test_slack_prune(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_prune(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,

View File

@@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
)
@@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair_2, now, timeout=120, user_performing_action=admin_user
)
@@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None:
assert info_2
assert info_1.num_docs_indexed == info_2.num_docs_indexed
def test_connector_pause_while_indexing(reset: None) -> None:
"""Tests that we can pause a connector while indexing is in progress and that
tasks end early or abort as a result.
TODO: This does not specifically test for soft or hard termination code paths.
Design specific tests for those use cases.
"""
admin_user: DATestUser = UserManager.create(name="admin_user")
config = {
"wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
"space": "",
"is_cloud": True,
"page_id": "",
}
credential = {
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
}
# store the time before we create the connector so that we know when
# the indexing should have started
now = datetime.now(timezone.utc)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.CONFLUENCE,
connector_specific_config=config,
credential_json=credential,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_in_progress(
cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user
)
CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user)
CCPairManager.wait_for_indexing_inactive(
cc_pair_1, timeout=60, user_performing_action=admin_user
)
return

View File

@@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=60, user_performing_action=admin_user
)

web/@types/favicon-fetch.d.ts (new vendored file, 9 lines)
View File

@@ -0,0 +1,9 @@
declare module "favicon-fetch" {
interface FaviconFetchOptions {
uri: string;
}
function faviconFetch(options: FaviconFetchOptions): string | null;
export default faviconFetch;
}

web/package-lock.json (generated, 1007 lines changed)

File diff suppressed because it is too large.

View File

@@ -17,11 +17,13 @@
"@headlessui/react": "^2.2.0",
"@headlessui/tailwindcss": "^0.2.1",
"@phosphor-icons/react": "^2.0.8",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-checkbox": "^1.1.2",
"@radix-ui/react-dialog": "^1.1.2",
"@radix-ui/react-popover": "^1.1.2",
"@radix-ui/react-select": "^2.1.2",
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.1",
"@radix-ui/react-tooltip": "^1.1.3",
"@sentry/nextjs": "^8.34.0",
@@ -37,6 +39,7 @@
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"date-fns": "^3.6.0",
"favicon-fetch": "^1.0.0",
"formik": "^2.2.9",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
@@ -67,6 +70,7 @@
"tailwindcss-animate": "^1.0.7",
"typescript": "5.0.3",
"uuid": "^9.0.1",
"vaul": "^1.1.1",
"yup": "^1.4.0"
},
"devDependencies": {

View File

@@ -405,7 +405,7 @@ export function AssistantEditor({
message: `"${assistant.name}" has been added to your list.`,
type: "success",
});
router.refresh();
await refreshAssistants();
} else {
setPopup({
message: `"${assistant.name}" could not be added to your list.`,

View File

@@ -90,7 +90,7 @@ export function PersonasTable() {
message: `Failed to update persona order - ${await response.text()}`,
});
setFinalPersonas(assistants);
router.refresh();
await refreshAssistants();
return;
}
@@ -151,7 +151,7 @@ export function PersonasTable() {
persona.is_visible
);
if (response.ok) {
router.refresh();
await refreshAssistants();
} else {
setPopup({
type: "error",
@@ -183,7 +183,7 @@ export function PersonasTable() {
onClick={async () => {
const response = await deletePersona(persona.id);
if (response.ok) {
router.refresh();
await refreshAssistants();
} else {
alert(
`Failed to delete persona - ${await response.text()}`

View File

@@ -259,29 +259,8 @@ export async function updatePersona(
): Promise<[Response, Response | null]> {
const { id, existingPromptId } = personaUpdateRequest;
let fileId = null;
if (personaUpdateRequest.uploaded_image) {
fileId = await uploadFile(personaUpdateRequest.uploaded_image);
if (!fileId) {
return [new Response(null, { status: 400 }), null];
}
}
const updatePersonaResponse = await fetch(`/api/persona/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaAPIBody(personaUpdateRequest, existingPromptId ?? 0, fileId)
),
});
if (!updatePersonaResponse.ok) {
return [updatePersonaResponse, null];
}
let promptResponse;
let promptId: number | null = null;
if (existingPromptId !== undefined) {
promptResponse = await updatePrompt({
promptId: existingPromptId,
@@ -290,6 +269,7 @@ export async function updatePersona(
taskPrompt: personaUpdateRequest.task_prompt,
includeCitations: personaUpdateRequest.include_citations,
});
promptId = existingPromptId;
} else {
promptResponse = await createPrompt({
personaName: personaUpdateRequest.name,
@@ -297,7 +277,30 @@ export async function updatePersona(
taskPrompt: personaUpdateRequest.task_prompt,
includeCitations: personaUpdateRequest.include_citations,
});
promptId = promptResponse.ok
? ((await promptResponse.json()).id as number)
: null;
}
let fileId = null;
if (personaUpdateRequest.uploaded_image) {
fileId = await uploadFile(personaUpdateRequest.uploaded_image);
if (!fileId) {
return [promptResponse, null];
}
}
const updatePersonaResponse =
promptResponse.ok && promptId !== null
? await fetch(`/api/persona/${id}`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(
buildPersonaAPIBody(personaUpdateRequest, promptId, fileId)
),
})
: null;
return [promptResponse, updatePersonaResponse];
}

View File

@@ -60,21 +60,24 @@ export function SlackChannelConfigsTable({
.slice(numToDisplay * (page - 1), numToDisplay * page)
.map((slackChannelConfig) => {
return (
<TableRow key={slackChannelConfig.id}>
<TableRow
key={slackChannelConfig.id}
className="cursor-pointer hover:bg-gray-100 transition-colors"
onClick={() => {
window.location.href = `/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`;
}}
>
<TableCell>
<div className="flex gap-x-2">
<Link
className="cursor-pointer my-auto"
href={`/admin/bots/${slackBotId}/channels/${slackChannelConfig.id}`}
>
<div className="my-auto">
<EditIcon />
</Link>
</div>
<div className="my-auto">
{"#" + slackChannelConfig.channel_config.channel_name}
</div>
</div>
</TableCell>
<TableCell>
<TableCell onClick={(e) => e.stopPropagation()}>
{slackChannelConfig.persona &&
!isPersonaASlackBotPersona(slackChannelConfig.persona) ? (
<Link
@@ -98,10 +101,11 @@ export function SlackChannelConfigsTable({
: "-"}
</div>
</TableCell>
<TableCell>
<TableCell onClick={(e) => e.stopPropagation()}>
<div
className="cursor-pointer hover:text-destructive"
onClick={async () => {
onClick={async (e) => {
e.stopPropagation();
const response = await deleteSlackChannelConfig(
slackChannelConfig.id
);

View File

@@ -81,6 +81,11 @@ export const SlackChannelConfigCreationForm = ({
respond_to_bots:
existingSlackChannelConfig?.channel_config?.respond_to_bots ||
false,
show_continue_in_web_ui:
// If we're updating, we want to keep the existing value
// Otherwise, we want to default to true
existingSlackChannelConfig?.channel_config
?.show_continue_in_web_ui ?? !isUpdate,
enable_auto_filters:
existingSlackChannelConfig?.enable_auto_filters || false,
respond_member_group_list:
@@ -119,6 +124,7 @@ export const SlackChannelConfigCreationForm = ({
questionmark_prefilter_enabled: Yup.boolean().required(),
respond_tag_only: Yup.boolean().required(),
respond_to_bots: Yup.boolean().required(),
show_continue_in_web_ui: Yup.boolean().required(),
enable_auto_filters: Yup.boolean().required(),
respond_member_group_list: Yup.array().of(Yup.string()).required(),
still_need_help_enabled: Yup.boolean().required(),
@@ -270,7 +276,13 @@ export const SlackChannelConfigCreationForm = ({
{showAdvancedOptions && (
<div className="mt-4">
<div className="w-64 mb-4">
<BooleanFormField
name="show_continue_in_web_ui"
removeIndent
label="Show Continue in Web UI button"
tooltip="If set, will show a button at the bottom of the response that allows the user to continue the conversation in the Danswer Web UI"
/>
<div className="w-64 mb-4 mt-4">
<SelectorFormField
name="response_type"
label="Answer Type"

View File

@@ -15,6 +15,7 @@ interface SlackChannelConfigCreationRequest {
questionmark_prefilter_enabled: boolean;
respond_tag_only: boolean;
respond_to_bots: boolean;
show_continue_in_web_ui: boolean;
respond_member_group_list: string[];
follow_up_tags?: string[];
usePersona: boolean;
@@ -43,6 +44,7 @@ const buildRequestBodyFromCreationRequest = (
channel_name: creationRequest.channel_name,
respond_tag_only: creationRequest.respond_tag_only,
respond_to_bots: creationRequest.respond_to_bots,
show_continue_in_web_ui: creationRequest.show_continue_in_web_ui,
enable_auto_filters: creationRequest.enable_auto_filters,
respond_member_group_list: creationRequest.respond_member_group_list,
answer_filters: buildFiltersFromCreationRequest(creationRequest),

View File

@@ -22,7 +22,6 @@ function SlackBotEditPage({
const unwrappedParams = use(params);
const { popup, setPopup } = usePopup();
console.log("unwrappedParams", unwrappedParams);
const {
data: slackBot,
isLoading: isSlackBotLoading,

View File

@@ -161,7 +161,7 @@ export default function UpgradingPage({
reindexingProgress={sortedReindexingProgress}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
<ErrorCallout errorTitle="Failed to fetch reindexing progress" />
)}
</>
) : (
@@ -171,7 +171,7 @@ export default function UpgradingPage({
</h3>
<p className="mb-4 text-text-800">
You&apos;re currently switching embedding models, but there
are no connectors to re-index. This means the transition will
are no connectors to reindex. This means the transition will
be quick and seamless!
</p>
<p className="text-text-600">

View File

@@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { setCCPairStatus } from "@/lib/ccPair";
import { useState } from "react";
import { LoadingAnimation } from "@/components/Loading";
export function ModifyStatusButtonCluster({
ccPair,
@@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({
ccPair: CCPairFullInfo;
}) {
const { popup, setPopup } = usePopup();
const [isUpdating, setIsUpdating] = useState(false);
const handleStatusChange = async (
newStatus: ConnectorCredentialPairStatus
) => {
if (isUpdating) return; // Prevent double-clicks or multiple requests
setIsUpdating(true);
try {
// Call the backend to update the status
await setCCPairStatus(ccPair.id, newStatus, setPopup);
// Use mutate to revalidate the status on the backend
await mutate(buildCCPairInfoUrl(ccPair.id));
} catch (error) {
console.error("Failed to update status", error);
} finally {
// Reset local updating state and button text after mutation
setIsUpdating(false);
}
};
// Compute the button text based on current state and backend status
const buttonText =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Re-Enable"
: "Pause";
const tooltip =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Click to start indexing again!"
: "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";
return (
<>
{popup}
{ccPair.status === ConnectorCredentialPairStatus.PAUSED ? (
<Button
variant="success-reverse"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.ACTIVE,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip="Click to start indexing again!"
>
Re-Enable
</Button>
) : (
<Button
variant="default"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.PAUSED,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip={
"When paused, the connectors documents will still" +
" be visible. However, no new documents will be indexed."
}
>
Pause
</Button>
)}
<Button
className="flex items-center justify-center w-auto min-w-[100px] px-4 py-2"
variant={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "success-reverse"
: "default"
}
disabled={isUpdating}
onClick={() =>
handleStatusChange(
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? ConnectorCredentialPairStatus.ACTIVE
: ConnectorCredentialPairStatus.PAUSED
)
}
tooltip={tooltip}
>
{isUpdating ? (
<LoadingAnimation
text={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Resuming"
: "Pausing"
}
size="text-md"
/>
) : (
buttonText
)}
</Button>
</>
);
}

View File

@@ -121,7 +121,7 @@ export function ReIndexButton({
{popup}
<Button
variant="success-reverse"
className="ml-auto"
className="ml-auto min-w-[100px]"
onClick={() => {
setReIndexPopupVisible(true);
}}

View File

@@ -25,6 +25,7 @@ import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not

View File

@@ -83,7 +83,7 @@ const EditRow = ({
</div>
</TooltipTrigger>
{!documentSet.is_up_to_date && (
<TooltipContent maxWidth="max-w-sm">
<TooltipContent width="max-w-sm">
<div className="flex break-words break-keep whitespace-pre-wrap items-start">
<InfoIcon className="mr-2 mt-0.5" />
Cannot update while syncing! Wait for the sync to finish, then

View File

@@ -175,29 +175,6 @@ export function SettingsForm() {
{ fieldName, newValue: checked },
];
// If we're disabling a page, check if we need to update the default page
if (
!checked &&
(fieldName === "search_page_enabled" || fieldName === "chat_page_enabled")
) {
const otherPageField =
fieldName === "search_page_enabled"
? "chat_page_enabled"
: "search_page_enabled";
const otherPageEnabled = settings && settings[otherPageField];
if (
otherPageEnabled &&
settings?.default_page ===
(fieldName === "search_page_enabled" ? "search" : "chat")
) {
updates.push({
fieldName: "default_page",
newValue: fieldName === "search_page_enabled" ? "chat" : "search",
});
}
}
updateSettingField(updates);
}
@@ -218,42 +195,17 @@ export function SettingsForm() {
return (
<div>
{popup}
<Title className="mb-4">Page Visibility</Title>
<Title className="mb-4">Workspace Settings</Title>
<Checkbox
label="Search Page Enabled?"
sublabel="If set, then the 'Search' page will be accessible to all users and will show up as an option on the top navbar. If unset, then this page will not be available."
checked={settings.search_page_enabled}
label="Auto-scroll"
sublabel="If set, the chat window will automatically scroll to the bottom as new lines of text are generated by the AI model."
checked={settings.auto_scroll}
onChange={(e) =>
handleToggleSettingsField("search_page_enabled", e.target.checked)
handleToggleSettingsField("auto_scroll", e.target.checked)
}
/>
<Checkbox
label="Chat Page Enabled?"
sublabel="If set, then the 'Chat' page will be accessible to all users and will show up as an option on the top navbar. If unset, then this page will not be available."
checked={settings.chat_page_enabled}
onChange={(e) =>
handleToggleSettingsField("chat_page_enabled", e.target.checked)
}
/>
<Selector
label="Default Page"
subtext="The page that users will be redirected to after logging in. Can only be set to a page that is enabled."
options={[
{ value: "search", name: "Search" },
{ value: "chat", name: "Chat" },
]}
selected={settings.default_page}
onSelect={(value) => {
value &&
updateSettingField([
{ fieldName: "default_page", newValue: value },
]);
}}
/>
{isEnterpriseEnabled && (
<>
<Title className="mb-4">Chat Settings</Title>

View File

@@ -5,14 +5,12 @@ export enum GatingType {
}
export interface Settings {
chat_page_enabled: boolean;
search_page_enabled: boolean;
default_page: "search" | "chat";
maximum_chat_retention_days: number | null;
notifications: Notification[];
needs_reindexing: boolean;
gpu_enabled: boolean;
product_gating: GatingType;
auto_scroll: boolean;
}
export enum NotificationType {
@@ -54,6 +52,7 @@ export interface EnterpriseSettings {
custom_popup_header: string | null;
custom_popup_content: string | null;
enable_consent_screen: boolean | null;
auto_scroll: boolean;
}
export interface CombinedSettings {

View File

@@ -15,10 +15,12 @@ export function EmailPasswordForm({
isSignup = false,
shouldVerify,
referralSource,
nextUrl,
}: {
isSignup?: boolean;
shouldVerify?: boolean;
referralSource?: string;
nextUrl?: string | null;
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
@@ -69,7 +71,7 @@ export function EmailPasswordForm({
await requestEmailVerification(values.email);
router.push("/auth/waiting-on-verification");
} else {
router.push("/");
router.push(nextUrl ? encodeURI(nextUrl) : "/");
}
} else {
setIsWorking(false);

View File

@@ -22,6 +22,9 @@ const Page = async (props: {
}) => {
const searchParams = await props.searchParams;
const autoRedirectDisabled = searchParams?.disableAutoRedirect === "true";
const nextUrl = Array.isArray(searchParams?.next)
? searchParams?.next[0]
: searchParams?.next || null;
// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
@@ -37,10 +40,6 @@ const Page = async (props: {
console.log(`Some fetch failed for the login page - ${e}`);
}
const nextUrl = Array.isArray(searchParams?.next)
? searchParams?.next[0]
: searchParams?.next || null;
// simply take the user to the home page if Auth is disabled
if (authTypeMetadata?.authType === "disabled") {
return redirect("/");
@@ -100,12 +99,15 @@ const Page = async (props: {
<span className="px-4 text-gray-500">or</span>
<div className="flex-grow border-t border-gray-300"></div>
</div>
<EmailPasswordForm shouldVerify={true} />
<EmailPasswordForm shouldVerify={true} nextUrl={nextUrl} />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
<Link
href={`/auth/signup${searchParams?.next ? `?next=${searchParams.next}` : ""}`}
className="text-link font-medium"
>
Create an account
</Link>
</Text>
@@ -120,11 +122,14 @@ const Page = async (props: {
<LoginText />
</Title>
</div>
<EmailPasswordForm />
<EmailPasswordForm nextUrl={nextUrl} />
<div className="flex">
<Text className="mt-4 mx-auto">
Don&apos;t have an account?{" "}
<Link href="/auth/signup" className="text-link font-medium">
<Link
href={`/auth/signup${searchParams?.next ? `?next=${searchParams.next}` : ""}`}
className="text-link font-medium"
>
Create an account
</Link>
</Text>

View File

@@ -15,7 +15,14 @@ import AuthFlowContainer from "@/components/auth/AuthFlowContainer";
import ReferralSourceSelector from "./ReferralSourceSelector";
import { Separator } from "@/components/ui/separator";
const Page = async () => {
const Page = async (props: {
searchParams?: Promise<{ [key: string]: string | string[] | undefined }>;
}) => {
const searchParams = await props.searchParams;
const nextUrl = Array.isArray(searchParams?.next)
? searchParams?.next[0]
: searchParams?.next || null;
// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
// will not render
@@ -86,12 +93,19 @@ const Page = async () => {
<EmailPasswordForm
isSignup
shouldVerify={authTypeMetadata?.requiresVerification}
nextUrl={nextUrl}
/>
<div className="flex">
<Text className="mt-4 mx-auto">
Already have an account?{" "}
<Link href="/auth/login" className="text-link font-medium">
<Link
href={{
pathname: "/auth/login",
query: { ...searchParams },
}}
className="text-link font-medium"
>
Log In
</Link>
</Text>

View File

@@ -8,7 +8,6 @@ import {
ChatFileType,
ChatSession,
ChatSessionSharedStatus,
DocumentsResponse,
FileDescriptor,
FileChatDisplay,
Message,
@@ -60,7 +59,7 @@ import { useDocumentSelection } from "./useDocumentSelection";
import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
import { computeAvailableFilters } from "@/lib/filters";
import { ChatState, FeedbackType, RegenerationState } from "./types";
import { DocumentSidebar } from "./documentSidebar/DocumentSidebar";
import { ChatFilters } from "./documentSidebar/ChatFilters";
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
import { FeedbackModal } from "./modal/FeedbackModal";
import { ShareChatSessionModal } from "./modal/ShareChatSessionModal";
@@ -71,6 +70,7 @@ import { StarterMessages } from "../../components/assistants/StarterMessage";
import {
AnswerPiecePacket,
DanswerDocument,
FinalContextDocs,
StreamStopInfo,
StreamStopReason,
} from "@/lib/search/interfaces";
@@ -105,14 +105,9 @@ import BlurBackground from "./shared_chat_search/BlurBackground";
import { NoAssistantModal } from "@/components/modals/NoAssistantModal";
import { useAssistants } from "@/components/context/AssistantsContext";
import { Separator } from "@/components/ui/separator";
import {
Card,
CardContent,
CardDescription,
CardHeader,
} from "@/components/ui/card";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import AssistantBanner from "../../components/assistants/AssistantBanner";
import AssistantSelector from "@/components/chat_search/AssistantSelector";
import { Modal } from "@/components/Modal";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -132,8 +127,9 @@ export function ChatPage({
const {
chatSessions,
availableSources,
availableDocumentSets,
ccPairs,
tags,
documentSets,
llmProviders,
folders,
openedFolders,
@@ -142,6 +138,36 @@ export function ChatPage({
shouldShowWelcomeModal,
refreshChatSessions,
} = useChatContext();
function useScreenSize() {
const [screenSize, setScreenSize] = useState({
width: typeof window !== "undefined" ? window.innerWidth : 0,
height: typeof window !== "undefined" ? window.innerHeight : 0,
});
useEffect(() => {
const handleResize = () => {
setScreenSize({
width: window.innerWidth,
height: window.innerHeight,
});
};
window.addEventListener("resize", handleResize);
return () => window.removeEventListener("resize", handleResize);
}, []);
return screenSize;
}
const { height: screenHeight } = useScreenSize();
const getContainerHeight = () => {
if (autoScrollEnabled) return undefined;
if (screenHeight < 600) return "20vh";
if (screenHeight < 1200) return "30vh";
return "40vh";
};
// handle redirect if chat page is disabled
// NOTE: this must be done here, in a client component since
@@ -149,9 +175,11 @@ export function ChatPage({
// available in server-side components
const settings = useContext(SettingsContext);
const enterpriseSettings = settings?.enterpriseSettings;
if (settings?.settings?.chat_page_enabled === false) {
router.push("/search");
}
const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false);
const [filtersToggled, setFiltersToggled] = useState(false);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
const { assistants: availableAssistants, finalAssistants } = useAssistants();
@@ -159,14 +187,13 @@ export function ChatPage({
!shouldShowWelcomeModal
);
const { user, isAdmin, isLoadingUser, refreshUser } = useUser();
const { user, isAdmin, isLoadingUser } = useUser();
const slackChatId = searchParams.get("slackChatId");
const existingChatIdRaw = searchParams.get("chatId");
const [sendOnLoad, setSendOnLoad] = useState<string | null>(
searchParams.get(SEARCH_PARAM_NAMES.SEND_ON_LOAD)
);
const currentPersonaId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
const modelVersionFromSearchParams = searchParams.get(
SEARCH_PARAM_NAMES.STRUCTURED_MODEL
);
@@ -267,6 +294,17 @@ export function ChatPage({
availableAssistants[0];
const noAssistants = liveAssistant == null || liveAssistant == undefined;
const availableSources = ccPairs.map((ccPair) => ccPair.source);
const [finalAvailableSources, finalAvailableDocumentSets] =
computeAvailableFilters({
selectedPersona: availableAssistants.find(
(assistant) => assistant.id === liveAssistant?.id
),
availableSources: availableSources,
availableDocumentSets: documentSets,
});
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
useEffect(() => {
if (noAssistants) return;
@@ -356,9 +394,7 @@ export function ChatPage({
textAreaRef.current?.focus();
// only clear things if we're going from one chat session to another
const isChatSessionSwitch =
chatSessionIdRef.current !== null &&
existingChatSessionId !== priorChatSessionId;
const isChatSessionSwitch = existingChatSessionId !== priorChatSessionId;
if (isChatSessionSwitch) {
// de-select documents
clearSelectedDocuments();
@@ -404,6 +440,7 @@ export function ChatPage({
}
return;
}
setIsReady(true);
const shouldScrollToBottom =
visibleRange.get(existingChatSessionId) === undefined ||
visibleRange.get(existingChatSessionId)?.end == 0;
@@ -447,9 +484,9 @@ export function ChatPage({
}
if (shouldScrollToBottom) {
if (!hasPerformedInitialScroll) {
if (!hasPerformedInitialScroll && autoScrollEnabled) {
clientScrollToBottom();
} else if (isChatSessionSwitch) {
} else if (isChatSessionSwitch && autoScrollEnabled) {
clientScrollToBottom(true);
}
}
@@ -469,9 +506,12 @@ export function ChatPage({
});
// force re-name if the chat session doesn't have one
if (!chatSession.description) {
await nameChatSession(existingChatSessionId, seededMessage);
await nameChatSession(existingChatSessionId);
refreshChatSessions();
}
} else if (newMessageHistory.length === 2 && !chatSession.description) {
await nameChatSession(existingChatSessionId);
refreshChatSessions();
}
}
@@ -828,11 +868,13 @@ export function ChatPage({
0
)}px`;
scrollableDivRef?.current.scrollBy({
left: 0,
top: Math.max(heightDifference, 0),
behavior: "smooth",
});
if (autoScrollEnabled) {
scrollableDivRef?.current.scrollBy({
left: 0,
top: Math.max(heightDifference, 0),
behavior: "smooth",
});
}
}
previousHeight.current = newHeight;
}
@@ -879,6 +921,7 @@ export function ChatPage({
endDivRef.current.scrollIntoView({
behavior: fast ? "auto" : "smooth",
});
setHasPerformedInitialScroll(true);
}
}, 50);
@@ -1030,7 +1073,9 @@ export function ChatPage({
}
setAlternativeGeneratingAssistant(alternativeAssistantOverride);
clientScrollToBottom();
let currChatSessionId: string;
const isNewSession = chatSessionIdRef.current === null;
const searchParamBasedChatSessionName =
@@ -1276,8 +1321,8 @@ export function ChatPage({
if (Object.hasOwn(packet, "answer_piece")) {
answer += (packet as AnswerPiecePacket).answer_piece;
} else if (Object.hasOwn(packet, "top_documents")) {
documents = (packet as DocumentsResponse).top_documents;
} else if (Object.hasOwn(packet, "final_context_docs")) {
documents = (packet as FinalContextDocs).final_context_docs;
retrievalType = RetrievalType.Search;
if (documents && documents.length > 0) {
// point to the latest message (we don't know the messageId yet, which is why
@@ -1374,8 +1419,7 @@ export function ChatPage({
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents:
finalMessage?.context_docs?.top_documents || documents,
documents: documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCall: finalMessage?.tool_call || toolCall,
@@ -1429,7 +1473,7 @@ export function ChatPage({
if (!searchParamBasedChatSessionName) {
await new Promise((resolve) => setTimeout(resolve, 200));
await nameChatSession(currChatSessionId, currMessage);
await nameChatSession(currChatSessionId);
refreshChatSessions();
}
@@ -1594,6 +1638,11 @@ export function ChatPage({
mobile: settings?.isMobile,
});
const autoScrollEnabled =
user?.preferences?.auto_scroll == null
? settings?.enterpriseSettings?.auto_scroll || false
: user?.preferences?.auto_scroll!;
useScrollonStream({
chatState: currentSessionChatState,
scrollableDivRef,
@@ -1602,6 +1651,7 @@ export function ChatPage({
debounceNumber,
waitForScrollRef,
mobile: settings?.isMobile,
enableAutoScroll: autoScrollEnabled,
});
// Virtualization + Scrolling related effects and functions
@@ -1751,6 +1801,13 @@ export function ChatPage({
liveAssistant
);
});
useEffect(() => {
if (!retrievalEnabled) {
setDocumentSidebarToggled(false);
}
}, [retrievalEnabled]);
const [stackTraceModalContent, setStackTraceModalContent] = useState<
string | null
>(null);
@@ -1759,7 +1816,41 @@ export function ChatPage({
const [settingsToggled, setSettingsToggled] = useState(false);
const currentPersona = alternativeAssistant || liveAssistant;
useEffect(() => {
const handleSlackChatRedirect = async () => {
if (!slackChatId) return;
// Set isReady to false before starting retrieval to display loading text
setIsReady(false);
try {
const response = await fetch("/api/chat/seed-chat-session-from-slack", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
chat_session_id: slackChatId,
}),
});
if (!response.ok) {
throw new Error("Failed to seed chat from Slack");
}
const data = await response.json();
router.push(data.redirect_url);
} catch (error) {
console.error("Error seeding chat from Slack:", error);
setPopup({
message: "Failed to load chat from Slack",
type: "error",
});
}
};
handleSlackChatRedirect();
}, [searchParams, router]);
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.metaKey || event.ctrlKey) {
@@ -1790,9 +1881,30 @@ export function ChatPage({
setSharedChatSession(chatSession);
};
const [documentSelection, setDocumentSelection] = useState(false);
const toggleDocumentSelectionAspects = () => {
setDocumentSelection((documentSelection) => !documentSelection);
setShowDocSidebar(false);
// const toggleDocumentSelectionAspects = () => {
// setDocumentSelection((documentSelection) => !documentSelection);
// setShowDocSidebar(false);
// };
const toggleDocumentSidebar = () => {
if (!documentSidebarToggled) {
setFiltersToggled(false);
setDocumentSidebarToggled(true);
} else if (!filtersToggled) {
setDocumentSidebarToggled(false);
} else {
setFiltersToggled(false);
}
};
const toggleFilters = () => {
if (!documentSidebarToggled) {
setFiltersToggled(true);
setDocumentSidebarToggled(true);
} else if (filtersToggled) {
setDocumentSidebarToggled(false);
} else {
setFiltersToggled(true);
}
};
interface RegenerationRequest {
@@ -1850,16 +1962,46 @@ export function ChatPage({
/>
)}
{settingsToggled && (
{(settingsToggled || userSettingsToggled) && (
<SetDefaultModelModal
setPopup={setPopup}
setLlmOverride={llmOverrideManager.setGlobalDefault}
defaultModel={user?.preferences.default_model!}
llmProviders={llmProviders}
onClose={() => setSettingsToggled(false)}
onClose={() => {
setUserSettingsToggled(false);
setSettingsToggled(false);
}}
/>
)}
{retrievalEnabled && documentSidebarToggled && settings?.isMobile && (
<div className="md:hidden">
<Modal noPadding noScroll>
<ChatFilters
modal={true}
filterManager={filterManager}
ccPairs={ccPairs}
tags={tags}
documentSets={documentSets}
ref={innerSidebarElementRef}
showFilters={filtersToggled}
closeSidebar={() => {
setDocumentSidebarToggled(false);
}}
selectedMessage={aiMessage}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
initialWidth={400}
isOpen={true}
/>
</Modal>
</div>
)}
{deletingChatSession && (
<DeleteEntityModal
entityType="chat"
@@ -1960,6 +2102,50 @@ export function ChatPage({
</div>
</div>
</div>
{!settings?.isMobile && retrievalEnabled && (
<div
style={{ transition: "width 0.30s ease-out" }}
className={`
flex-none
fixed
right-0
z-[1000]
bg-background
h-screen
transition-all
bg-opacity-80
duration-300
ease-in-out
bg-transparent
h-full
${documentSidebarToggled ? "w-[400px]" : "w-[0px]"}
`}
>
<ChatFilters
modal={false}
filterManager={filterManager}
ccPairs={ccPairs}
tags={tags}
documentSets={documentSets}
ref={innerSidebarElementRef}
showFilters={filtersToggled}
closeSidebar={() => setDocumentSidebarToggled(false)}
selectedMessage={aiMessage}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
initialWidth={400}
isOpen={documentSidebarToggled}
/>
</div>
)}
<BlurBackground
visible={!untoggled && (showDocSidebar || toggledSidebar)}
@@ -1969,9 +2155,12 @@ export function ChatPage({
ref={masterFlexboxRef}
className="flex h-full w-full overflow-x-hidden"
>
<div className="flex h-full flex-col w-full">
<div className="flex h-full relative px-2 flex-col w-full">
{liveAssistant && (
<FunctionalHeader
toggleUserSettings={() => setUserSettingsToggled(true)}
liveAssistant={liveAssistant}
onAssistantChange={onAssistantChange}
sidebarToggled={toggledSidebar}
reset={() => setMessage("")}
page="chat"
@@ -1982,6 +2171,8 @@ export function ChatPage({
}
toggleSidebar={toggleSidebar}
currentChatSession={selectedChatSession}
documentSidebarToggled={documentSidebarToggled}
llmOverrideManager={llmOverrideManager}
/>
)}
@@ -2003,7 +2194,7 @@ export function ChatPage({
duration-300
ease-in-out
h-full
${toggledSidebar ? "w-[250px]" : "w-[0px]"}
${toggledSidebar ? "w-[200px]" : "w-[0px]"}
`}
></div>
)}
@@ -2013,9 +2204,55 @@ export function ChatPage({
{...getRootProps()}
>
<div
className={`w-full h-full flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
className={`w-full h-[calc(100vh-160px)] flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
ref={scrollableDivRef}
>
{liveAssistant && onAssistantChange && (
<div className="z-20 fixed top-4 pointer-events-none left-0 w-full flex justify-center overflow-visible">
{!settings?.isMobile && (
<div
style={{ transition: "width 0.30s ease-out" }}
className={`
flex-none
overflow-y-hidden
transition-all
pointer-events-none
duration-300
ease-in-out
h-full
${toggledSidebar ? "w-[200px]" : "w-[0px]"}
`}
></div>
)}
<AssistantSelector
isMobile={settings?.isMobile!}
liveAssistant={liveAssistant}
onAssistantChange={onAssistantChange}
llmOverrideManager={llmOverrideManager}
/>
{!settings?.isMobile && (
<div
style={{ transition: "width 0.30s ease-out" }}
className={`
flex-none
overflow-y-hidden
transition-all
duration-300
ease-in-out
h-full
pointer-events-none
${
documentSidebarToggled && retrievalEnabled
? "w-[400px]"
: "w-[0px]"
}
`}
></div>
)}
</div>
)}
{/* ChatBanner is a custom banner that displays an admin-specified message at
the top of the chat page. Only used in the EE version of the app. */}
@@ -2023,7 +2260,7 @@ export function ChatPage({
!isFetchingChatMessages &&
currentSessionChatState == "input" &&
!loadingError && (
<div className="h-full mt-12 flex flex-col justify-center items-center">
<div className="h-full w-[95%] mx-auto mt-12 flex flex-col justify-center items-center">
<ChatIntro selectedPersona={liveAssistant} />
<StarterMessages
@@ -2045,6 +2282,7 @@ export function ChatPage({
Recent Assistants
</div>
<AssistantBanner
mobile={settings?.isMobile}
recentAssistants={recentAssistants}
liveAssistant={liveAssistant}
allAssistants={allAssistants}
@@ -2186,6 +2424,14 @@ export function ChatPage({
}
>
<AIMessage
index={i}
selectedMessageForDocDisplay={
selectedMessageForDocDisplay
}
documentSelectionToggled={
documentSidebarToggled &&
!filtersToggled
}
continueGenerating={
i == messageHistory.length - 1 &&
currentCanContinue()
@@ -2222,9 +2468,19 @@ export function ChatPage({
}}
isActive={messageHistory.length - 1 == i}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={
toggleDocumentSelectionAspects
}
toggleDocumentSelection={() => {
if (
!documentSidebarToggled ||
(documentSidebarToggled &&
selectedMessageForDocDisplay ===
message.messageId)
) {
toggleDocumentSidebar();
}
setSelectedMessageForDocDisplay(
message.messageId
);
}}
docs={message.documents}
currentPersona={liveAssistant}
alternativeAssistant={
@@ -2232,7 +2488,6 @@ export function ChatPage({
}
messageId={message.messageId}
content={message.message}
// content={message.message}
files={message.files}
query={
messageHistory[i]?.query || undefined
@@ -2418,6 +2673,15 @@ export function ChatPage({
/>
</div>
)}
{messageHistory.length > 0 && (
<div
style={{
height: !autoScrollEnabled
? getContainerHeight()
: undefined,
}}
/>
)}
{/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
<div ref={endPaddingRef} className="h-[95px]" />
@@ -2441,6 +2705,15 @@ export function ChatPage({
</div>
)}
<ChatInputBar
removeDocs={() => {
clearSelectedDocuments();
}}
removeFilters={() => {
filterManager.setSelectedSources([]);
filterManager.setSelectedTags([]);
filterManager.setSelectedDocumentSets([]);
setDocumentSidebarToggled(false);
}}
showConfigureAPIKey={() =>
setShowApiKeyModal(true)
}
@@ -2463,6 +2736,9 @@ export function ChatPage({
llmOverrideManager={llmOverrideManager}
files={currentMessageFiles}
setFiles={setCurrentMessageFiles}
toggleFilters={
retrievalEnabled ? toggleFilters : undefined
}
handleFileUpload={handleImageUpload}
textAreaRef={textAreaRef}
chatSessionId={chatSessionIdRef.current!}
@@ -2493,6 +2769,23 @@ export function ChatPage({
</div>
</div>
</div>
{!settings?.isMobile && (
<div
style={{ transition: "width 0.30s ease-out" }}
className={`
flex-none
overflow-y-hidden
transition-all
duration-300
ease-in-out
${
documentSidebarToggled && retrievalEnabled
? "w-[400px]"
: "w-[0px]"
}
`}
></div>
)}
</div>
)}
</Dropzone>
@@ -2501,7 +2794,11 @@ export function ChatPage({
<div
style={{ transition: "width 0.30s ease-out" }}
className={`flex-none bg-transparent transition-all bg-opacity-80 duration-300 ease-in-out h-full
${toggledSidebar ? "w-[250px] " : "w-[0px]"}`}
${
toggledSidebar && !settings?.isMobile
? "w-[250px] "
: "w-[0px]"
}`}
/>
<div className="my-auto">
<DanswerInitializingLoader />
@@ -2512,20 +2809,8 @@ export function ChatPage({
</div>
<FixedLogo backgroundToggled={toggledSidebar || showDocSidebar} />
</div>
{/* Right Sidebar - DocumentSidebar */}
</div>
<DocumentSidebar
initialWidth={350}
ref={innerSidebarElementRef}
closeSidebar={() => setDocumentSelection(false)}
selectedMessage={aiMessage}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
isLoading={isFetchingChatMessages}
isOpen={documentSelection}
/>
</>
);
}

View File

@@ -1,133 +1,117 @@
import { HoverPopup } from "@/components/HoverPopup";
import { SourceIcon } from "@/components/SourceIcon";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { DanswerDocument } from "@/lib/search/interfaces";
import { FiInfo, FiRadio } from "react-icons/fi";
import { FiTag } from "react-icons/fi";
import { DocumentSelector } from "./DocumentSelector";
import {
DocumentMetadataBlock,
buildDocumentSummaryDisplay,
} from "@/components/search/DocumentDisplay";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { buildDocumentSummaryDisplay } from "@/components/search/DocumentDisplay";
import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBadge";
import { MetadataBadge } from "@/components/MetadataBadge";
import { WebResultIcon } from "@/components/WebResultIcon";
interface DocumentDisplayProps {
document: DanswerDocument;
queryEventId: number | null;
isAIPick: boolean;
modal?: boolean;
isSelected: boolean;
handleSelect: (documentId: string) => void;
setPopup: (popupSpec: PopupSpec | null) => void;
tokenLimitReached: boolean;
}
export function DocumentMetadataBlock({
modal,
document,
}: {
modal?: boolean;
document: DanswerDocument;
}) {
const MAX_METADATA_ITEMS = 3;
const metadataEntries = Object.entries(document.metadata);
return (
<div className="flex items-center overflow-hidden">
{document.updated_at && (
<DocumentUpdatedAtBadge updatedAt={document.updated_at} modal={modal} />
)}
{metadataEntries.length > 0 && (
<>
<div className="mx-1 h-4 border-l border-border" />
<div className="flex items-center overflow-hidden">
{metadataEntries
.slice(0, MAX_METADATA_ITEMS)
.map(([key, value], index) => (
<MetadataBadge
key={index}
icon={FiTag}
value={`${key}=${value}`}
/>
))}
{metadataEntries.length > MAX_METADATA_ITEMS && (
<span className="ml-1 text-xs text-gray-500">...</span>
)}
</div>
</>
)}
</div>
);
}
export function ChatDocumentDisplay({
document,
queryEventId,
isAIPick,
modal,
isSelected,
handleSelect,
setPopup,
tokenLimitReached,
}: DocumentDisplayProps) {
const isInternet = document.is_internet;
// Consider reintroducing null scored docs in the future
if (document.score === null) {
return null;
}
return (
<div
key={document.semantic_identifier}
className={`p-2 w-[325px] justify-start rounded-md ${
isSelected ? "bg-background-200" : "bg-background-125"
} text-sm mx-3`}
>
<div className="flex relative justify-start overflow-y-visible">
<div className={`opacity-100 ${modal ? "w-[90vw]" : "w-full"}`}>
<div
className={`flex relative flex-col gap-0.5 rounded-xl mx-2 my-1 ${
isSelected ? "bg-gray-200" : "hover:bg-background-125"
}`}
>
<a
href={document.link}
target="_blank"
className={
"rounded-lg flex font-bold flex-shrink truncate" +
(document.link ? "" : "pointer-events-none")
}
rel="noreferrer"
rel="noopener noreferrer"
className="cursor-pointer flex flex-col px-2 py-1.5"
>
{isInternet ? (
<InternetSearchIcon url={document.link} />
) : (
<SourceIcon sourceType={document.source_type} iconSize={18} />
)}
<p className="overflow-hidden text-left text-ellipsis mx-2 my-auto text-sm">
{document.semantic_identifier || document.document_id}
</p>
</a>
{document.score !== null && (
<div className="my-auto">
{isAIPick && (
<div className="w-4 h-4 my-auto mr-1 flex flex-col">
<HoverPopup
mainContent={<FiRadio className="text-gray-500 my-auto" />}
popupContent={
<div className="text-xs text-gray-300 w-36 flex">
<div className="flex mx-auto">
<div className="w-3 h-3 flex flex-col my-auto mr-1">
<FiInfo className="my-auto" />
</div>
<div className="my-auto">The AI liked this doc!</div>
</div>
</div>
}
direction="bottom"
style="dark"
/>
</div>
<div className="line-clamp-1 mb-1 flex h-6 items-center gap-2 text-xs">
{document.is_internet || document.source_type === "web" ? (
<WebResultIcon url={document.link} />
) : (
<SourceIcon sourceType={document.source_type} iconSize={18} />
)}
<div
className={`
text-xs
text-emphasis
bg-hover
rounded
p-0.5
w-fit
my-auto
select-none
my-auto
mr-2`}
>
{Math.abs(document.score).toFixed(2)}
<div className="line-clamp-1 text-text-900 text-sm font-semibold">
{(document.semantic_identifier || document.document_id).length >
(modal ? 30 : 40)
? `${(document.semantic_identifier || document.document_id)
.slice(0, modal ? 30 : 40)
.trim()}...`
: document.semantic_identifier || document.document_id}
</div>
</div>
)}
{!isInternet && (
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
)}
</div>
<div>
<div className="mt-1">
<DocumentMetadataBlock document={document} />
</div>
</div>
<p className="line-clamp-3 pl-1 pt-2 mb-1 text-start break-words">
{buildDocumentSummaryDisplay(document.match_highlights, document.blurb)}
test
</p>
<div className="mb-2">
{/*
// TODO: find a way to include this
{queryEventId && (
<DocumentFeedbackBlock
documentId={document.document_id}
queryId={queryEventId}
setPopup={setPopup}
/>
)} */}
<DocumentMetadataBlock modal={modal} document={document} />
<div className="line-clamp-3 pt-2 text-sm font-normal leading-snug text-gray-600">
{buildDocumentSummaryDisplay(
document.match_highlights,
document.blurb
)}
</div>
<div className="absolute top-2 right-2">
{!isInternet && (
<DocumentSelector
isSelected={isSelected}
handleSelect={() => handleSelect(document.document_id)}
isDisabled={tokenLimitReached && !isSelected}
/>
)}
</div>
</a>
</div>
</div>
);

View File

@@ -0,0 +1,186 @@
import { DanswerDocument } from "@/lib/search/interfaces";
import { ChatDocumentDisplay } from "./ChatDocumentDisplay";
import { usePopup } from "@/components/admin/connectors/Popup";
import { removeDuplicateDocs } from "@/lib/documentUtils";
import { Message } from "../interfaces";
import { ForwardedRef, forwardRef, useEffect, useState } from "react";
import { FilterManager } from "@/lib/hooks";
import { CCPairBasicInfo, DocumentSet, Tag } from "@/lib/types";
import { SourceSelector } from "../shared_chat_search/SearchFilters";
import { XIcon } from "@/components/icons/icons";
interface ChatFiltersProps {
filterManager: FilterManager;
closeSidebar: () => void;
selectedMessage: Message | null;
selectedDocuments: DanswerDocument[] | null;
toggleDocumentSelection: (document: DanswerDocument) => void;
clearSelectedDocuments: () => void;
selectedDocumentTokens: number;
maxTokens: number;
initialWidth: number;
isOpen: boolean;
modal: boolean;
ccPairs: CCPairBasicInfo[];
tags: Tag[];
documentSets: DocumentSet[];
showFilters: boolean;
}
export const ChatFilters = forwardRef<HTMLDivElement, ChatFiltersProps>(
(
{
closeSidebar,
modal,
selectedMessage,
selectedDocuments,
filterManager,
toggleDocumentSelection,
clearSelectedDocuments,
selectedDocumentTokens,
maxTokens,
initialWidth,
isOpen,
ccPairs,
tags,
documentSets,
showFilters,
},
ref: ForwardedRef<HTMLDivElement>
) => {
const { popup, setPopup } = usePopup();
const [delayedSelectedDocumentCount, setDelayedSelectedDocumentCount] =
useState(0);
useEffect(() => {
const timer = setTimeout(
() => {
setDelayedSelectedDocumentCount(selectedDocuments?.length || 0);
},
selectedDocuments?.length == 0 ? 1000 : 0
);
return () => clearTimeout(timer);
}, [selectedDocuments]);
const selectedDocumentIds =
selectedDocuments?.map((document) => document.document_id) || [];
const currentDocuments = selectedMessage?.documents || null;
const dedupedDocuments = removeDuplicateDocs(currentDocuments || []);
const tokenLimitReached = selectedDocumentTokens > maxTokens - 75;
const hasSelectedDocuments = selectedDocumentIds.length > 0;
return (
<div
id="danswer-chat-sidebar"
className={`relative py-2 max-w-full ${
!modal ? "border-l h-full border-sidebar-border" : ""
}`}
onClick={(e) => {
if (e.target === e.currentTarget) {
closeSidebar();
}
}}
>
<div
className={`ml-auto h-full relative sidebar transition-all duration-300
${
isOpen
? "opacity-100 translate-x-0"
: "opacity-0 translate-x-[10%]"
}`}
style={{
width: modal ? undefined : initialWidth,
}}
>
<div className="flex flex-col h-full">
{popup}
<div className="p-4 flex justify-between items-center">
<h2 className="text-xl font-bold text-text-900">
{showFilters ? "Filters" : "Sources"}
</h2>
<button
onClick={closeSidebar}
className="text-sm text-primary-600 mr-2 hover:text-primary-800 transition-colors duration-200 ease-in-out"
>
<XIcon className="w-4 h-4" />
</button>
</div>
<div className="border-b border-divider-history-sidebar-bar mx-3" />
<div className="overflow-y-auto -mx-1 sm:mx-0 flex-grow gap-y-0 default-scrollbar dark-scrollbar flex flex-col">
{showFilters ? (
<SourceSelector
modal={modal}
tagsOnLeft={true}
filtersUntoggled={false}
{...filterManager}
availableDocumentSets={documentSets}
existingSources={ccPairs.map((ccPair) => ccPair.source)}
availableTags={tags}
/>
) : (
<>
{dedupedDocuments.length > 0 ? (
dedupedDocuments.map((document, ind) => (
<div
key={document.document_id}
className={`${
ind === dedupedDocuments.length - 1
? ""
: "border-b border-border-light w-full"
}`}
>
<ChatDocumentDisplay
modal={modal}
document={document}
isSelected={selectedDocumentIds.includes(
document.document_id
)}
handleSelect={(documentId) => {
toggleDocumentSelection(
dedupedDocuments.find(
(doc) => doc.document_id === documentId
)!
);
}}
tokenLimitReached={tokenLimitReached}
/>
</div>
))
) : (
<div className="mx-3" />
)}
</>
)}
</div>
</div>
{!showFilters && (
<div
className={`sticky bottom-4 w-full left-0 flex justify-center transition-opacity duration-300 ${
hasSelectedDocuments
? "opacity-100"
: "opacity-0 pointer-events-none"
}`}
>
<button
className="text-sm font-medium py-2 px-4 rounded-full transition-colors bg-gray-900 text-white"
onClick={clearSelectedDocuments}
>
{`Remove ${
delayedSelectedDocumentCount > 0
? delayedSelectedDocumentCount
: ""
} Source${delayedSelectedDocumentCount > 1 ? "s" : ""}`}
</button>
</div>
)}
</div>
</div>
);
}
);
ChatFilters.displayName = "ChatFilters";

View File

@@ -1,168 +0,0 @@
import { DanswerDocument } from "@/lib/search/interfaces";
import Text from "@/components/ui/text";
import { ChatDocumentDisplay } from "./ChatDocumentDisplay";
import { usePopup } from "@/components/admin/connectors/Popup";
import { removeDuplicateDocs } from "@/lib/documentUtils";
import { Message } from "../interfaces";
import { ForwardedRef, forwardRef } from "react";
import { Separator } from "@/components/ui/separator";
interface DocumentSidebarProps {
closeSidebar: () => void;
selectedMessage: Message | null;
selectedDocuments: DanswerDocument[] | null;
toggleDocumentSelection: (document: DanswerDocument) => void;
clearSelectedDocuments: () => void;
selectedDocumentTokens: number;
maxTokens: number;
isLoading: boolean;
initialWidth: number;
isOpen: boolean;
}
export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
(
{
closeSidebar,
selectedMessage,
selectedDocuments,
toggleDocumentSelection,
clearSelectedDocuments,
selectedDocumentTokens,
maxTokens,
isLoading,
initialWidth,
isOpen,
},
ref: ForwardedRef<HTMLDivElement>
) => {
const { popup, setPopup } = usePopup();
const selectedDocumentIds =
selectedDocuments?.map((document) => document.document_id) || [];
const currentDocuments = selectedMessage?.documents || null;
const dedupedDocuments = removeDuplicateDocs(currentDocuments || []);
// NOTE: do not allow selection if less than 75 tokens are left
// this is to prevent the case where they are able to select the doc
// but it basically is unused since it's truncated right at the very
// start of the document (since title + metadata + misc overhead takes up
// space)
const tokenLimitReached = selectedDocumentTokens > maxTokens - 75;
return (
<div
id="danswer-chat-sidebar"
className={`fixed inset-0 transition-opacity duration-300 z-50 bg-black/80 ${
isOpen ? "opacity-100" : "opacity-0 pointer-events-none"
}`}
onClick={(e) => {
if (e.target === e.currentTarget) {
closeSidebar();
}
}}
>
<div
className={`ml-auto rounded-l-lg relative border-l bg-text-100 sidebar z-50 absolute right-0 h-screen transition-all duration-300 ${
isOpen ? "opacity-100 translate-x-0" : "opacity-0 translate-x-[10%]"
}`}
ref={ref}
style={{
width: initialWidth,
}}
>
<div className="pb-6 flex-initial overflow-y-hidden flex flex-col h-screen">
{popup}
<div className="pl-3 mx-2 pr-6 mt-3 flex text-text-800 flex-col text-2xl text-emphasis flex font-semibold">
{dedupedDocuments.length} Document
{dedupedDocuments.length > 1 ? "s" : ""}
<p className="text-sm font-semibold flex flex-wrap gap-x-2 text-text-600 mt-1">
Select to add to continuous context
<a
href="https://docs.danswer.dev/introduction"
className="underline cursor-pointer hover:text-strong"
>
Learn more
</a>
</p>
</div>
<Separator className="mb-0 mt-4 pb-2" />
{currentDocuments ? (
<div className="overflow-y-auto flex-grow dark-scrollbar flex relative flex-col">
{dedupedDocuments.length > 0 ? (
dedupedDocuments.map((document, ind) => (
<div
key={document.document_id}
className={`${
ind === dedupedDocuments.length - 1
? "mb-5"
: "border-b border-border-light mb-3"
}`}
>
<ChatDocumentDisplay
document={document}
setPopup={setPopup}
queryEventId={null}
isAIPick={false}
isSelected={selectedDocumentIds.includes(
document.document_id
)}
handleSelect={(documentId) => {
toggleDocumentSelection(
dedupedDocuments.find(
(document) => document.document_id === documentId
)!
);
}}
tokenLimitReached={tokenLimitReached}
/>
</div>
))
) : (
<div className="mx-3">
<Text>No documents found for the query.</Text>
</div>
)}
</div>
) : (
!isLoading && (
<div className="ml-4 mr-3">
<Text>
When you ask a question, the retrieved documents will
show up here!
</Text>
</div>
)
)}
</div>
<div className="absolute left-0 bottom-0 w-full bg-gradient-to-b from-neutral-100/0 via-neutral-100/40 backdrop-blur-xs to-neutral-100 h-[100px]" />
<div className="sticky bottom-4 w-full left-0 justify-center flex gap-x-4">
<button
className="bg-[#84e49e] text-xs p-2 rounded text-text-800"
onClick={() => closeSidebar()}
>
Save Changes
</button>
<button
className="bg-error text-xs p-2 rounded text-text-200"
onClick={() => {
clearSelectedDocuments();
closeSidebar();
}}
>
Delete Context
</button>
</div>
</div>
</div>
);
}
);
DocumentSidebar.displayName = "DocumentSidebar";

View File

@@ -1,13 +1,9 @@
import React, { useContext, useEffect, useRef, useState } from "react";
import { FiPlusCircle, FiPlus, FiInfo, FiX } from "react-icons/fi";
import { FiPlusCircle, FiPlus, FiInfo, FiX, FiSearch } from "react-icons/fi";
import { ChatInputOption } from "./ChatInputOption";
import { Persona } from "@/app/admin/assistants/interfaces";
import { InputPrompt } from "@/app/admin/prompt-library/interfaces";
import {
FilterManager,
getDisplayNameForModel,
LlmOverrideManager,
} from "@/lib/hooks";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { SelectedFilterDisplay } from "./SelectedFilterDisplay";
import { useChatContext } from "@/components/context/ChatContext";
import { getFinalLLM } from "@/lib/llm/utils";
@@ -18,15 +14,10 @@ import {
} from "../files/InputBarPreview";
import {
AssistantsIconSkeleton,
CpuIconSkeleton,
FileIcon,
SendIcon,
StopGeneratingIcon,
} from "@/components/icons/icons";
import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
import { LlmTab } from "../modal/configuration/LlmTab";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import { DanswerDocument } from "@/lib/search/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import {
@@ -40,10 +31,18 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
import { ChatState } from "../types";
import UnconfiguredProviderText from "@/components/chat_search/UnconfiguredProviderText";
import { useAssistants } from "@/components/context/AssistantsContext";
import AnimatedToggle from "@/components/search/SearchBar";
import { Popup } from "@/components/admin/connectors/Popup";
import { AssistantsTab } from "../modal/configuration/AssistantsTab";
import { IconType } from "react-icons";
import { LlmTab } from "../modal/configuration/LlmTab";
import { XIcon } from "lucide-react";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
removeFilters,
removeDocs,
openModelSettings,
showDocs,
showConfigureAPIKey,
@@ -68,7 +67,10 @@ export function ChatInputBar({
alternativeAssistant,
chatSessionId,
inputPrompts,
toggleFilters,
}: {
removeFilters: () => void;
removeDocs: () => void;
showConfigureAPIKey: () => void;
openModelSettings: () => void;
chatState: ChatState;
@@ -90,6 +92,7 @@ export function ChatInputBar({
handleFileUpload: (files: File[]) => void;
textAreaRef: React.RefObject<HTMLTextAreaElement>;
chatSessionId?: string;
toggleFilters?: () => void;
}) {
useEffect(() => {
const textarea = textAreaRef.current;
@@ -370,9 +373,9 @@ export function ChatInputBar({
</div>
)}
<div>
{/* <div>
<SelectedFilterDisplay filterManager={filterManager} />
</div>
</div> */}
<UnconfiguredProviderText showConfigureAPIKey={showConfigureAPIKey} />
@@ -429,16 +432,21 @@ export function ChatInputBar({
)}
{(selectedDocuments.length > 0 || files.length > 0) && (
<div className="flex gap-x-2 px-2 pt-2">
<div className="flex gap-x-1 px-2 overflow-y-auto overflow-x-scroll items-end miniscroll">
<div className="flex gap-x-1 px-2 overflow-visible overflow-x-scroll items-end miniscroll">
{selectedDocuments.length > 0 && (
<button
onClick={showDocs}
className="flex-none flex cursor-pointer hover:bg-background-200 transition-colors duration-300 h-10 p-1 items-center gap-x-1 rounded-lg bg-background-150 max-w-[100px]"
className="flex-none relative overflow-visible flex items-center gap-x-2 h-10 px-3 rounded-lg bg-background-150 hover:bg-background-200 transition-colors duration-300 cursor-pointer max-w-[150px]"
>
<FileIcon size={24} />
<p className="text-xs">
<FileIcon size={20} />
<span className="text-sm whitespace-nowrap overflow-hidden text-ellipsis">
{selectedDocuments.length} selected
</p>
</span>
<XIcon
onClick={removeDocs}
size={16}
className="text-text-400 hover:text-text-600 ml-auto"
/>
</button>
)}
{files.map((file) => (
@@ -529,72 +537,6 @@ export function ChatInputBar({
suppressContentEditableWarning={true}
/>
<div className="flex items-center space-x-3 mr-12 px-4 pb-2">
<Popup
removePadding
content={(close) => (
<AssistantsTab
llmProviders={llmProviders}
selectedAssistant={selectedAssistant}
onSelect={(assistant) => {
setSelectedAssistant(assistant);
close();
}}
/>
)}
flexPriority="shrink"
position="top"
mobilePosition="top-right"
>
<ChatInputOption
toggle
flexPriority="shrink"
name={
selectedAssistant ? selectedAssistant.name : "Assistants"
}
Icon={AssistantsIconSkeleton as IconType}
/>
</Popup>
<Popup
tab
content={(close, ref) => (
<LlmTab
currentAssistant={alternativeAssistant || selectedAssistant}
openModelSettings={openModelSettings}
currentLlm={
llmOverrideManager.llmOverride.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override ||
llmOverrideManager.globalDefault.modelName ||
llmName
: llmName)
}
close={close}
ref={ref}
llmOverrideManager={llmOverrideManager}
chatSessionId={chatSessionId}
/>
)}
position="top"
>
<ChatInputOption
flexPriority="second"
toggle
name={
settings?.isMobile
? undefined
: getDisplayNameForModel(
llmOverrideManager.llmOverride.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override ||
llmOverrideManager.globalDefault.modelName ||
llmName
: llmName)
)
}
Icon={CpuIconSkeleton}
/>
</Popup>
<ChatInputOption
flexPriority="stiff"
name="File"
@@ -614,6 +556,14 @@ export function ChatInputBar({
input.click();
}}
/>
{toggleFilters && (
<ChatInputOption
flexPriority="stiff"
name="Filters"
Icon={FiSearch}
onClick={toggleFilters}
/>
)}
</div>
<div className="absolute bottom-2.5 mobile:right-4 desktop:right-10">


@@ -2,6 +2,7 @@ import {
AnswerPiecePacket,
DanswerDocument,
Filters,
FinalContextDocs,
StreamStopInfo,
} from "@/lib/search/interfaces";
import { handleSSEStream } from "@/lib/search/streamingUtils";
@@ -102,6 +103,7 @@ export type PacketType =
| ToolCallMetadata
| BackendMessage
| AnswerPiecePacket
| FinalContextDocs
| DocumentsResponse
| FileChatDisplay
| StreamingError
@@ -147,7 +149,6 @@ export async function* sendMessage({
}): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
const body = JSON.stringify({
alternate_assistant_id: alternateAssistantId,
chat_session_id: chatSessionId,
@@ -203,7 +204,7 @@ export async function* sendMessage({
yield* handleSSEStream<PacketType>(response);
}
export async function nameChatSession(chatSessionId: string, message: string) {
export async function nameChatSession(chatSessionId: string) {
const response = await fetch("/api/chat/rename-chat-session", {
method: "PUT",
headers: {
@@ -212,7 +213,6 @@ export async function nameChatSession(chatSessionId: string, message: string) {
body: JSON.stringify({
chat_session_id: chatSessionId,
name: null,
first_message: message,
}),
});
return response;
@@ -263,7 +263,6 @@ export async function renameChatSession(
body: JSON.stringify({
chat_session_id: chatSessionId,
name: newName,
first_message: null,
}),
});
return response;
@@ -641,6 +640,7 @@ export async function useScrollonStream({
endDivRef,
debounceNumber,
mobile,
enableAutoScroll,
}: {
chatState: ChatState;
scrollableDivRef: RefObject<HTMLDivElement>;
@@ -649,6 +649,7 @@ export async function useScrollonStream({
endDivRef: RefObject<HTMLDivElement>;
debounceNumber: number;
mobile?: boolean;
enableAutoScroll?: boolean;
}) {
const mobileDistance = 900; // distance that should "engage" the scroll
const desktopDistance = 500; // distance that should "engage" the scroll
@@ -661,6 +662,10 @@ export async function useScrollonStream({
const previousScroll = useRef<number>(0);
useEffect(() => {
if (!enableAutoScroll) {
return;
}
if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) {
const newHeight: number = scrollableDivRef.current?.scrollTop!;
const heightDifference = newHeight - previousScroll.current;
@@ -718,7 +723,7 @@ export async function useScrollonStream({
// scroll on end of stream if within distance
useEffect(() => {
if (scrollableDivRef?.current && chatState == "input") {
if (scrollableDivRef?.current && chatState == "input" && enableAutoScroll) {
if (scrollDist.current < distance - 50) {
scrollableDivRef?.current?.scrollBy({
left: 0,
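The `useScrollonStream` hunks above add an `enableAutoScroll` flag and return early from the scroll effects when it is off. A minimal sketch of that gating pattern, with a simplified hook name and signature (not the real hook):

```tsx
import { RefObject, useEffect } from "react";

// Bail out of the effect entirely when auto-scroll is disabled; otherwise keep
// the end-of-stream marker in view while the response is streaming.
export function useAutoScrollSketch(
  endRef: RefObject<HTMLDivElement>,
  isStreaming: boolean,
  enableAutoScroll?: boolean
) {
  useEffect(() => {
    if (!enableAutoScroll) {
      return;
    }
    if (isStreaming) {
      endRef.current?.scrollIntoView({ behavior: "smooth" });
    }
  }, [endRef, isStreaming, enableAutoScroll]);
}
```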


@@ -1,8 +1,50 @@
import { Citation } from "@/components/search/results/Citation";
import { WebResultIcon } from "@/components/WebResultIcon";
import { LoadedDanswerDocument } from "@/lib/search/interfaces";
import { getSourceMetadata } from "@/lib/sources";
import { ValidSources } from "@/lib/types";
import React, { memo } from "react";
import isEqual from "lodash/isEqual";
export const MemoizedAnchor = memo(({ docs, children }: any) => {
console.log(children);
const value = children?.toString();
if (value?.startsWith("[") && value?.endsWith("]")) {
const match = value.match(/\[(\d+)\]/);
if (match) {
const index = parseInt(match[1], 10) - 1;
const associatedDoc = docs && docs[index];
const url = associatedDoc?.link
? new URL(associatedDoc.link).origin + "/favicon.ico"
: "";
const getIcon = (sourceType: ValidSources, link: string) => {
return getSourceMetadata(sourceType).icon({ size: 18 });
};
const icon =
associatedDoc?.source_type === "web" ? (
<WebResultIcon url={associatedDoc.link} />
) : (
getIcon(
associatedDoc?.source_type || "web",
associatedDoc?.link || ""
)
);
return (
<MemoizedLink document={{ ...associatedDoc, icon, url }}>
{children}
</MemoizedLink>
);
}
}
return <MemoizedLink>{children}</MemoizedLink>;
});
export const MemoizedLink = memo((props: any) => {
const { node, ...rest } = props;
const { node, document, ...rest } = props;
const value = rest.children;
if (value?.toString().startsWith("*")) {
@@ -10,7 +52,16 @@ export const MemoizedLink = memo((props: any) => {
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
);
} else if (value?.toString().startsWith("[")) {
return <Citation link={rest?.href}>{rest.children}</Citation>;
return (
<Citation
url={document?.url}
icon={document?.icon as React.ReactNode}
link={rest?.href}
document={document as LoadedDanswerDocument}
>
{rest.children}
</Citation>
);
} else {
return (
<a
@@ -25,9 +76,16 @@ export const MemoizedLink = memo((props: any) => {
}
});
export const MemoizedParagraph = memo(({ ...props }: any) => {
return <p {...props} className="text-default" />;
});
export const MemoizedParagraph = memo(
function MemoizedParagraph({ children }: any) {
return <p className="text-default">{children}</p>;
},
(prevProps, nextProps) => {
const areEqual = isEqual(prevProps.children, nextProps.children);
return areEqual;
}
);
MemoizedAnchor.displayName = "MemoizedAnchor";
MemoizedLink.displayName = "MemoizedLink";
MemoizedParagraph.displayName = "MemoizedParagraph";
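`MemoizedAnchor` above turns a citation token like `[2]` into a lookup against the retrieved documents (1-based in the text, 0-based in the array). That index arithmetic, isolated as a small helper for illustration (the helper name is hypothetical):

```ts
// Resolve a "[n]" citation token to the n-th retrieved document.
export function resolveCitation<T>(token: string, docs: T[]): T | undefined {
  const match = token.match(/\[(\d+)\]/);
  if (!match) {
    return undefined;
  }
  const index = parseInt(match[1], 10) - 1;
  return index >= 0 ? docs[index] : undefined;
}

// resolveCitation("[2]", ["doc A", "doc B", "doc C"]) === "doc B"
```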


@@ -8,14 +8,22 @@ import {
FiGlobe,
} from "react-icons/fi";
import { FeedbackType } from "../types";
import React, { useContext, useEffect, useMemo, useRef, useState } from "react";
import React, {
memo,
useCallback,
useContext,
useEffect,
useMemo,
useRef,
useState,
} from "react";
import ReactMarkdown from "react-markdown";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { SearchSummary } from "./SearchSummary";
import { SourceIcon } from "@/components/SourceIcon";
import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton";
@@ -36,8 +44,6 @@ import "prismjs/themes/prism-tomorrow.css";
import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import { LikeFeedback, DislikeFeedback } from "@/components/icons/icons";
import {
@@ -52,16 +58,18 @@ import {
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText } from "./codeUtils";
import ToolResult from "../../../components/tools/ToolResult";
import CsvContent from "../../../components/tools/CSVContent";
import SourceCard, {
SeeMoreBlock,
} from "@/components/chat_search/sources/SourceCard";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -155,6 +163,7 @@ function FileDisplay({
export const AIMessage = ({
regenerate,
overriddenModel,
selectedMessageForDocDisplay,
continueGenerating,
shared,
isActive,
@@ -162,6 +171,7 @@ export const AIMessage = ({
alternativeAssistant,
docs,
messageId,
documentSelectionToggled,
content,
files,
selectedDocuments,
@@ -178,7 +188,10 @@ export const AIMessage = ({
currentPersona,
otherMessagesCanSwitchTo,
onMessageSelection,
index,
}: {
index?: number;
selectedMessageForDocDisplay?: number | null;
shared?: boolean;
isActive?: boolean;
continueGenerating?: () => void;
@@ -191,6 +204,7 @@ export const AIMessage = ({
currentPersona: Persona;
messageId: number | null;
content: string | JSX.Element;
documentSelectionToggled?: boolean;
files?: FileDescriptor[];
query?: string;
citedDocuments?: [string, DanswerDocument][] | null;
@@ -287,18 +301,31 @@ export const AIMessage = ({
});
}
const paragraphCallback = useCallback(
(props: any) => <MemoizedParagraph>{props.children}</MemoizedParagraph>,
[]
);
const anchorCallback = useCallback(
(props: any) => (
<MemoizedAnchor docs={docs}>{props.children}</MemoizedAnchor>
),
[docs]
);
const currentMessageInd = messageId
? otherMessagesCanSwitchTo?.indexOf(messageId)
: undefined;
const uniqueSources: ValidSources[] = Array.from(
new Set((docs || []).map((doc) => doc.source_type))
).slice(0, 3);
const markdownComponents = useMemo(
() => ({
a: MemoizedLink,
p: MemoizedParagraph,
code: ({ node, className, children, ...props }: any) => {
a: anchorCallback,
p: paragraphCallback,
code: ({ node, className, children }: any) => {
const codeText = extractCodeText(
node,
finalContent as string,
@@ -312,7 +339,7 @@ export const AIMessage = ({
);
},
}),
[finalContent]
[anchorCallback, paragraphCallback, finalContent]
);
const renderedMarkdown = useMemo(() => {
@@ -333,12 +360,11 @@ export const AIMessage = ({
onMessageSelection &&
otherMessagesCanSwitchTo &&
otherMessagesCanSwitchTo.length > 1;
return (
<div
id="danswer-ai-message"
ref={trackedElementRef}
className={"py-5 ml-4 px-5 relative flex "}
className={`py-5 ml-4 px-5 relative flex `}
>
<div
className={`mx-auto ${
@@ -363,6 +389,7 @@ export const AIMessage = ({
!retrievalDisabled && (
<div className="mb-1">
<SearchSummary
index={index || 0}
query={query}
finished={toolCall?.tool_result != undefined}
hasDocs={hasDocs || false}
@@ -423,6 +450,31 @@ export const AIMessage = ({
/>
)}
{docs && docs.length > 0 && (
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
<div className="w-full">
<div className="px-8 flex gap-x-2">
{!settings?.isMobile &&
docs.length > 0 &&
docs
.slice(0, 2)
.map((doc, ind) => (
<SourceCard doc={doc} key={ind} />
))}
<SeeMoreBlock
documentSelectionToggled={
(documentSelectionToggled &&
selectedMessageForDocDisplay === messageId) ||
false
}
toggleDocumentSelection={toggleDocumentSelection}
uniqueSources={uniqueSources}
/>
</div>
</div>
</div>
)}
{content || files ? (
<>
<FileDisplay files={files || []} />
@@ -438,81 +490,6 @@ export const AIMessage = ({
) : isComplete ? null : (
<></>
)}
{isComplete && docs && docs.length > 0 && (
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
<div className="w-full">
<div className="px-8 flex gap-x-2">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.slice(0, 2).map((doc, ind) => (
<div
key={doc.document_id}
className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
`}
>
<a
href={doc.link || undefined}
target="_blank"
className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
rel="noreferrer"
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all">
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon
sourceType={doc.source_type}
iconSize={18}
/>
)}
</div>
</a>
<div className="flex overscroll-x-scroll mt-.5">
<DocumentMetadataBlock document={doc} />
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
{doc.blurb}
</div>
</div>
))}
<div
onClick={() => {
if (messageId) {
onMessageSelection?.(messageId);
}
toggleDocumentSelection?.();
}}
key={-1}
className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
>
<div className="text-sm flex justify-between font-semibold text-text-700">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1">
{uniqueSources.map((sourceType, ind) => {
return (
<div key={ind} className="flex-none">
<SourceIcon
sourceType={sourceType}
iconSize={18}
/>
</div>
);
})}
</div>
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
See more
</div>
</div>
</div>
</div>
</div>
)}
</div>
{handleFeedback &&
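The AIMessage hunks wrap the markdown renderers in `useCallback` and list them as dependencies of the `useMemo` that builds the `components` map, so `ReactMarkdown` only receives a new object when its inputs actually change. A compact sketch of that wiring, with simplified renderers (component and prop names here are illustrative):

```tsx
import React, { useCallback, useMemo } from "react";
import ReactMarkdown from "react-markdown";

function AnswerBodySketch({ content, docs }: { content: string; docs: unknown[] }) {
  // Stable paragraph renderer: no dependencies, created once.
  const paragraph = useCallback(
    (props: any) => <p className="text-default">{props.children}</p>,
    []
  );
  // Anchor renderer is recreated only when the docs list changes.
  const anchor = useCallback(
    (props: any) => (
      <a href={props.href} title={`${docs.length} sources retrieved`}>
        {props.children}
      </a>
    ),
    [docs]
  );
  // The components map is memoized on the renderers themselves.
  const components = useMemo(
    () => ({ a: anchor, p: paragraph }),
    [anchor, paragraph]
  );
  return <ReactMarkdown components={components}>{content}</ReactMarkdown>;
}

export default AnswerBodySketch;
```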


@@ -41,6 +41,7 @@ export function ShowHideDocsButton({
}
export function SearchSummary({
index,
query,
hasDocs,
finished,
@@ -48,6 +49,7 @@ export function SearchSummary({
handleShowRetrieved,
handleSearchQueryEdit,
}: {
index: number;
finished: boolean;
query: string;
hasDocs: boolean;
@@ -98,7 +100,14 @@ export function SearchSummary({
!text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
{finished ? "Searched" : "Searching"} for:{" "}
<i>
{index === 1
? finalQuery.length > 50
? `${finalQuery.slice(0, 50)}...`
: finalQuery
: finalQuery}
</i>
</div>
</div>
);
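The SearchSummary change truncates only the first summary's displayed query (`index === 1`) to 50 characters. The truncation itself, as a small helper for illustration:

```ts
// Cut a displayed query at 50 characters and append an ellipsis.
export function truncateQuery(query: string, maxLength = 50): string {
  return query.length > maxLength ? `${query.slice(0, maxLength)}...` : query;
}

// truncateQuery("what changed in the Q3 report about renewable energy targets?")
//   -> "what changed in the Q3 report about renewable ener..."
```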


@@ -5,15 +5,19 @@ import { FeedbackType } from "../types";
import { Modal } from "@/components/Modal";
import { FilledLikeIcon } from "@/components/icons/icons";
const predefinedPositiveFeedbackOptions =
process.env.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS?.split(",") ||
[];
const predefinedNegativeFeedbackOptions =
process.env.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS?.split(",") || [
"Retrieved documents were not relevant",
"AI misread the documents",
"Cited source had incorrect information",
];
const predefinedPositiveFeedbackOptions = process.env
.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS
? process.env.NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS.split(",")
: [];
const predefinedNegativeFeedbackOptions = process.env
.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS
? process.env.NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS.split(",")
: [
"Retrieved documents were not relevant",
"AI misread the documents",
"Cited source had incorrect information",
];
interface FeedbackModalProps {
feedbackType: FeedbackType;
@@ -49,7 +53,7 @@ export const FeedbackModal = ({
: predefinedNegativeFeedbackOptions;
return (
<Modal onOutsideClick={onClose} width="max-w-3xl">
<Modal onOutsideClick={onClose} width="w-full max-w-3xl">
<>
<h2 className="text-2xl text-emphasis font-bold mb-4 flex">
<div className="mr-1 my-auto">


@@ -1,4 +1,4 @@
import { Dispatch, SetStateAction, useEffect, useRef } from "react";
import { Dispatch, SetStateAction, useContext, useEffect, useRef } from "react";
import { Modal } from "@/components/Modal";
import Text from "@/components/ui/text";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
@@ -9,6 +9,10 @@ import { setUserDefaultModel } from "@/lib/users/UserSettings";
import { useRouter } from "next/navigation";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { useUser } from "@/components/user/UserProvider";
import { Separator } from "@/components/ui/separator";
import { Switch } from "@/components/ui/switch";
import { Label } from "@/components/admin/connectors/Field";
import { SettingsContext } from "@/components/settings/SettingsProvider";
export function SetDefaultModelModal({
setPopup,
@@ -23,7 +27,7 @@ export function SetDefaultModelModal({
onClose: () => void;
defaultModel: string | null;
}) {
const { refreshUser } = useUser();
const { refreshUser, user, updateUserAutoScroll } = useUser();
const containerRef = useRef<HTMLDivElement>(null);
const messageRef = useRef<HTMLDivElement>(null);
@@ -121,16 +125,41 @@ export function SetDefaultModelModal({
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);
const settings = useContext(SettingsContext);
const autoScroll = settings?.enterpriseSettings?.auto_scroll;
const checked =
user?.preferences?.auto_scroll === null
? autoScroll
: user?.preferences?.auto_scroll;
return (
<Modal onOutsideClick={onClose} width="rounded-lg bg-white max-w-xl">
<>
<div className="flex mb-4">
<h2 className="text-2xl text-emphasis font-bold flex my-auto">
Set Default Model
User settings
</h2>
</div>
<div className="flex flex-col gap-y-2">
<div className="flex items-center gap-x-2">
<Switch
checked={checked}
onCheckedChange={(checked) => {
updateUserAutoScroll(checked);
}}
/>
<Label className="text-sm">Enable auto-scroll</Label>
</div>
</div>
<Separator />
<h3 className="text-lg text-emphasis font-bold">
Default model for assistants
</h3>
<Text className="mb-4">
Choose a Large Language Model (LLM) to serve as the default for
assistants that don&apos;t have a default model assigned.
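The modal now derives the auto-scroll switch state from the user's `auto_scroll` preference, falling back to the workspace-level setting when that preference is null. A tiny sketch of the resolution (helper and type names are hypothetical; the component inlines the ternary):

```ts
interface UserPreferencesSketch {
  auto_scroll: boolean | null;
}

// A null (or missing) per-user value defers to the workspace default.
export function resolveAutoScroll(
  preferences: UserPreferencesSketch | undefined,
  workspaceDefault: boolean | undefined
): boolean | undefined {
  return preferences?.auto_scroll ?? workspaceDefault;
}
```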


@@ -32,6 +32,7 @@ export default async function Page(props: {
defaultAssistantId,
shouldShowWelcomeModal,
userInputPrompts,
ccPairs,
} = data;
return (
@@ -44,6 +45,9 @@ export default async function Page(props: {
value={{
chatSessions,
availableSources,
ccPairs,
documentSets,
tags,
availableDocumentSets: documentSets,
availableTags: tags,
llmProviders,
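The page hunk threads the server-fetched `ccPairs` into the chat context provider's value so client components can read it without refetching. A generic sketch of that pattern (provider and hook names here are hypothetical, not the actual ChatContext):

```tsx
import React, { createContext, useContext } from "react";

const ChatDataContext = createContext<{ ccPairs: unknown[] }>({ ccPairs: [] });

// Server component fetches ccPairs once and hands it to the provider.
export function ChatDataProvider({
  ccPairs,
  children,
}: {
  ccPairs: unknown[];
  children: React.ReactNode;
}) {
  return (
    <ChatDataContext.Provider value={{ ccPairs }}>
      {children}
    </ChatDataContext.Provider>
  );
}

// Client components read the shared value through a hook.
export const useChatData = () => useContext(ChatDataContext);
```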

Some files were not shown because too many files have changed in this diff.