Compare commits

5 Commits

Author SHA1 Message Date
hagen-danswer e2aaa60e77 done 2024-11-21 17:26:18 -08:00
hagen-danswer 31f4a68bee less passing around is_cloud 2024-11-21 16:58:31 -08:00
hagen-danswer 4fc196fc39 properly escaped the user query 2024-11-21 16:37:48 -08:00
hagen-danswer 95ab63b6bc reworked it 2024-11-21 16:35:08 -08:00
hagen-danswer 6d26d0b929 replace deprecated confluence group api endpoint 2024-11-21 12:37:16 -08:00
38 changed files with 314 additions and 398 deletions

View File

@@ -1,59 +0,0 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

View File

@@ -1,6 +1,5 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -15,9 +14,7 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
@@ -98,15 +95,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it

View File

@@ -4,6 +4,7 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -16,7 +17,6 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasn't been implemented for the given connector, just pull
@@ -111,15 +111,10 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

View File

@@ -19,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_pool import get_redis_client
@@ -118,7 +118,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletePayload(
fence_payload = RedisConnectorDeletionFenceData(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)

View File

@@ -3,7 +3,6 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
@@ -17,6 +16,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -33,8 +33,6 @@ from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
@@ -44,11 +42,9 @@ from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@@ -61,7 +57,7 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
class RunIndexingCallback(RunIndexingCallbackInterface):
def __init__(
self,
stop_key: str,
@@ -77,7 +73,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = ""
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
@@ -85,17 +80,15 @@ class IndexingCallback(IndexingHeartbeatInterface):
return True
return False
def progress(self, tag: str, amount: int) -> None:
def progress(self, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
@@ -104,54 +97,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
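
The removed function above relies on the inner/outer/inner double check spelled out in its comments. A minimal sketch of that pattern, with hypothetical read_db_state and fence_exists helpers standing in for the real DB and Redis lookups:

def confirm_bad_state(attempt_id: int) -> bool:
    # inner: observe the suspicious DB state (a non-terminal attempt)
    first = read_db_state(attempt_id)  # hypothetical helper
    if first is None:
        return False
    # outer: the fence is down, which might indicate an error
    if fence_exists(attempt_id):  # hypothetical helper
        return False
    # inner again: re-read and require the state to be unchanged;
    # a normal completion between the two reads would have changed it
    second = read_db_state(attempt_id)
    return second is not None and second.status == first.status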
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
@@ -162,7 +107,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -172,7 +117,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -191,18 +135,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
@@ -259,29 +198,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"search_settings={search_settings_instance.id} "
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -291,11 +207,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -400,11 +311,10 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -455,8 +365,6 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
@@ -477,16 +385,13 @@ def try_creating_indexing_task(
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(None)
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -504,7 +409,7 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -512,7 +417,7 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()
job = client.submit(
connector_indexing_task_wrapper,
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -523,7 +428,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -531,7 +436,7 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -555,7 +460,7 @@ def connector_indexing_proxy_task(
if job.status == "error":
task_logger.error(
f"Indexing watchdog - spawned task exceptioned: "
f"Indexing proxy - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -567,7 +472,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -575,38 +480,6 @@ def connector_indexing_proxy_task(
return
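
The proxy task's docstring above notes that celery tasks are forked but forking is unstable, so the real work runs in a separately spawned process. A rough standard-library sketch of that idea, using multiprocessing with an explicit "spawn" context rather than the project's SimpleJobClient:

import multiprocessing

def run_in_spawned_process(fn, *args) -> int | None:
    # "spawn" starts a fresh interpreter rather than forking the
    # already-threaded worker, sidestepping fork-related instability
    ctx = multiprocessing.get_context("spawn")
    proc = ctx.Process(target=fn, args=args)
    proc.start()
    proc.join()  # the real proxy polls job.status in a loop; join() is the simplest analogue
    return proc.exitcode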
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -661,7 +534,6 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -669,18 +541,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
if not redis_connector_index.fenced: # The fence must exist
# wait for the fence to come up
if not redis_connector_index.fenced:
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload # The payload must exist
payload = redis_connector_index.payload
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -703,7 +575,7 @@ def connector_indexing_task(
)
break
lock: RedisLock = r.lock(
lock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
@@ -712,7 +584,7 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -747,7 +619,7 @@ def connector_indexing_task(
)
# define a callback class
callback = IndexingCallback(
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,

View File

@@ -12,7 +12,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -277,7 +277,7 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = IndexingCallback(
callback = RunIndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -46,10 +47,13 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -645,26 +649,20 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
result_state = result.state
if (
result_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if status_int is None: # completion signal not set ... check for errors
# If we get here, and then the task both sets the completion signal and finishes,
# we will incorrectly abort the task. We must check result state, then check
# get_completion again to avoid the race condition.
if result_state in READY_STATES:
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
@@ -699,6 +697,37 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
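
The comments above prescribe a strict read order: snapshot the celery result state, check the completion signal, and only re-check the signal if the state was terminal. Condensed into a sketch, where result_of and completion_signal_set are hypothetical stand-ins for the AsyncResult and Redis reads:

from celery.states import READY_STATES

def task_looks_crashed(task_id: str) -> bool:
    state = result_of(task_id).state      # 1. snapshot result state first
    if completion_signal_set(task_id):    # 2. then read the signal
        return False  # finished normally
    # 3. once state is terminal the signal can no longer appear,
    #    so re-checking it here is race-free
    if state in READY_STATES:
        return not completion_signal_set(task_id)
    return False  # still running or setting up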
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if r.exists(fence_key):
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -750,6 +779,25 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"permissions_sync={n_permissions_sync} "
)
# Fail any index attempts in the DB that don't have fences
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)

View File

@@ -1,5 +1,7 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -29,7 +31,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
@@ -40,6 +42,19 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -91,7 +106,7 @@ def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -123,7 +138,13 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
callback=callback,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
)
indexing_pipeline = build_indexing_pipeline(
@@ -136,7 +157,6 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -208,9 +228,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
raise RuntimeError("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -271,7 +289,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -401,7 +419,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
callback: RunIndexingCallbackInterface | None = None,
) -> None:
try:
if is_ee:

View File

@@ -67,13 +67,6 @@ def create_index_attempt(
return new_attempt.id
def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt:
db_session.delete(index_attempt)
db_session.commit()
def mock_successful_index_attempt(
connector_credential_pair_id: int,
search_settings_id: int,

View File

@@ -1181,7 +1181,7 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Models to actually display to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True

View File

@@ -259,6 +259,7 @@ def get_personas(
) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
if not include_default:
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:

View File

@@ -10,7 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
@@ -125,7 +125,7 @@ class Chunker:
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
mini_chunk_size: int = MINI_CHUNK_SIZE,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
) -> None:
from llama_index.text_splitter import SentenceSplitter
@@ -134,7 +134,7 @@ class Chunker:
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.tokenizer = tokenizer
self.callback = callback
self.heartbeat = heartbeat
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
@@ -356,14 +356,9 @@ class Chunker:
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
final_chunks: list[DocAwareChunk] = []
for document in documents:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("Chunker.chunk: Stop signal detected")
final_chunks.extend(self._handle_single_document(document))
chunks = self._handle_single_document(document)
final_chunks.extend(chunks)
if self.callback:
self.callback.progress("Chunker.chunk", len(chunks))
if self.heartbeat:
self.heartbeat.heartbeat()
return final_chunks

View File

@@ -2,7 +2,7 @@ from abc import ABC
from abc import abstractmethod
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@@ -34,7 +34,7 @@ class IndexingEmbedder(ABC):
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
callback: IndexingHeartbeatInterface | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
self.normalize = normalize
@@ -60,7 +60,7 @@ class IndexingEmbedder(ABC):
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
retrim_content=True,
callback=callback,
heartbeat=heartbeat,
)
@abstractmethod
@@ -83,7 +83,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
model_name,
@@ -95,7 +95,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url,
api_version,
deployment_name,
callback,
heartbeat,
)
@log_function_time()
@@ -201,9 +201,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
@classmethod
def from_db_search_settings(
cls,
search_settings: SearchSettings,
callback: IndexingHeartbeatInterface | None = None,
cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
) -> "DefaultIndexingEmbedder":
return cls(
model_name=search_settings.model_name,
@@ -215,5 +213,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
callback=callback,
heartbeat=heartbeat,
)

View File

@@ -1,15 +1,41 @@
from abc import ABC
from abc import abstractmethod
import abc
from typing import Any
from sqlalchemy import func
from sqlalchemy.orm import Session
from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class IndexingHeartbeatInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
class Heartbeat(abc.ABC):
"""Useful for any long-running work that goes through a bunch of items
and needs to occasionally give updates on progress.
e.g. chunking, embedding, updating vespa, etc."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abc.abstractmethod
def heartbeat(self, metadata: Any = None) -> None:
raise NotImplementedError
@abstractmethod
def progress(self, tag: str, amount: int) -> None:
"""Send progress updates to the caller."""
class IndexingHeartbeat(Heartbeat):
def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
self.cnt = 0
self.index_attempt_id = index_attempt_id
self.db_session = db_session
self.freq = freq
def heartbeat(self, metadata: Any = None) -> None:
self.cnt += 1
if self.cnt % self.freq == 0:
index_attempt = get_index_attempt(
db_session=self.db_session, index_attempt_id=self.index_attempt_id
)
if index_attempt:
index_attempt.time_updated = func.now()
self.db_session.commit()
else:
logger.error("Index attempt not found, this should not happen!")

View File

@@ -34,7 +34,7 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import DocumentMetadata
from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.utils.logger import setup_logger
@@ -414,7 +414,6 @@ def build_indexing_pipeline(
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@@ -441,8 +440,13 @@ def build_indexing_pipeline(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass,
enable_large_chunks=enable_large_chunks,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
# after every doc, update status in case there are a bunch of
# really long docs
heartbeat=IndexingHeartbeat(
index_attempt_id=attempt_id, db_session=db_session, freq=1
)
if attempt_id
else None,
)
return partial(

View File

@@ -16,7 +16,7 @@ from danswer.configs.model_configs import (
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@@ -99,7 +99,7 @@ class EmbeddingModel:
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
callback: IndexingHeartbeatInterface | None = None,
heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
@@ -116,7 +116,7 @@ class EmbeddingModel:
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.callback = callback
self.heartbeat = heartbeat
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -160,10 +160,6 @@ class EmbeddingModel:
embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
@@ -183,8 +179,8 @@ class EmbeddingModel:
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
if self.heartbeat:
self.heartbeat.heartbeat()
return embeddings
def encode(

View File

@@ -17,7 +17,7 @@ from danswer.db.document import construct_document_select_for_connector_credenti
from danswer.db.models import Document as DbDocument
class RedisConnectorDeletePayload(BaseModel):
class RedisConnectorDeletionFenceData(BaseModel):
num_tasks: int | None
submitted: datetime
@@ -54,18 +54,20 @@ class RedisConnectorDelete:
return False
@property
def payload(self) -> RedisConnectorDeletePayload | None:
def payload(self) -> RedisConnectorDeletionFenceData | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorDeletePayload.model_validate_json(cast(str, fence_str))
payload = RedisConnectorDeletionFenceData.model_validate_json(
cast(str, fence_str)
)
return payload
def set_fence(self, payload: RedisConnectorDeletePayload | None) -> None:
def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
if not payload:
self.redis.delete(self.fence_key)
return
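
The write path is truncated here, but it presumably mirrors the model_validate_json read path above. A plausible completion of set_fence, where the redis.set call is an assumption not shown in this diff:

def set_fence(self, payload: RedisConnectorDeletionFenceData | None) -> None:
    if not payload:
        self.redis.delete(self.fence_key)
        return
    # assumption: the fence value is the pydantic model serialized to JSON,
    # matching the model_validate_json read in `payload` above
    self.redis.set(self.fence_key, payload.model_dump_json())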

View File

@@ -30,6 +30,7 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/llm")
basic_router = APIRouter(prefix="/llm")

View File

@@ -411,8 +411,6 @@ def _validate_curator_status__no_commit(
.all()
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -438,15 +436,6 @@ def update_user_curator_relationship(
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
if user.role == UserRole.ADMIN:
raise ValueError(
f"User '{user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=set_curator_request.user_id,

View File

@@ -1,16 +1,15 @@
from typing import Any
import pytest
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_heartbeat import Heartbeat
class MockHeartbeat(IndexingHeartbeatInterface):
class MockHeartbeat(Heartbeat):
def __init__(self) -> None:
self.call_count = 0
def should_stop(self) -> bool:
return False
def progress(self, tag: str, amount: int) -> None:
def heartbeat(self, metadata: Any = None) -> None:
self.call_count += 1

View File

@@ -74,7 +74,7 @@ def test_chunker_heartbeat(
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
callback=mock_heartbeat,
heartbeat=mock_heartbeat,
)
chunks = chunker.chunk([document])

View File

@@ -0,0 +1,80 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
@pytest.fixture
def mock_db_session() -> MagicMock:
return MagicMock(spec=Session)
@pytest.fixture
def mock_index_attempt() -> MagicMock:
return MagicMock(spec=IndexAttempt)
def test_indexing_heartbeat(
mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt:
mock_get_index_attempt.return_value = mock_index_attempt
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=5
)
# Test that heartbeat doesn't update before freq is reached
for _ in range(4):
heartbeat.heartbeat()
mock_db_session.commit.assert_not_called()
# Test that heartbeat updates when freq is reached
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
assert mock_index_attempt.time_updated is not None
mock_db_session.commit.assert_called_once()
# Reset mock calls
mock_db_session.reset_mock()
mock_get_index_attempt.reset_mock()
# Test that heartbeat updates again after freq more calls
for _ in range(5):
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
with patch(
"danswer.indexing.indexing_heartbeat.get_index_attempt"
) as mock_get_index_attempt, patch(
"danswer.indexing.indexing_heartbeat.logger"
) as mock_logger:
mock_get_index_attempt.return_value = None
heartbeat = IndexingHeartbeat(
index_attempt_id=1, db_session=mock_db_session, freq=1
)
heartbeat.heartbeat()
mock_get_index_attempt.assert_called_once_with(
db_session=mock_db_session, index_attempt_id=1
)
mock_logger.error.assert_called_once_with(
"Index attempt not found, this should not happen!"
)
mock_db_session.commit.assert_not_called()

View File

@@ -69,9 +69,6 @@ ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}
RUN npx next build
# Step 2. Production image, copy all the files and run next
@@ -137,12 +134,9 @@ ARG NEXT_PUBLIC_POSTHOG_KEY
ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}
# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to

View File

@@ -142,8 +142,6 @@ export function CustomLLMProviderUpdateForm({
},
body: JSON.stringify({
...values,
// For custom llm providers, all model names are displayed
display_model_names: values.model_names,
custom_config: customConfigProcessing(values.custom_config_list),
}),
});

View File

@@ -52,7 +52,6 @@ import {
useLayoutEffect,
useRef,
useState,
useMemo,
} from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
@@ -267,6 +266,7 @@ export function ChatPage({
availableAssistants[0];
const noAssistants = liveAssistant == null || liveAssistant == undefined;
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
useEffect(() => {
const personaDefault = getLLMProviderOverrideForPersona(
@@ -282,7 +282,7 @@ export function ChatPage({
);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [liveAssistant, user?.preferences.default_model]);
}, [liveAssistant, llmProviders, user?.preferences.default_model]);
const stopGenerating = () => {
const currentSession = currentSessionId();
@@ -2007,7 +2007,7 @@ export function ChatPage({
{...getRootProps()}
>
<div
className={`w-full h-full flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
className={`w-full h-full flex flex-col overflow-y-auto include-scrollbar overflow-x-hidden relative`}
ref={scrollableDivRef}
>
{/* ChatBanner is a custom banner that displays a admin-specified message at

View File

@@ -76,7 +76,7 @@ export function AssistantsTab({
items={assistants.map((a) => a.id.toString())}
strategy={verticalListSortingStrategy}
>
<div className="px-4 pb-2 max-h-[500px] default-scrollbar overflow-y-scroll overflow-x-hidden my-3 grid grid-cols-1 gap-4">
<div className="px-4 pb-2 max-h-[500px] include-scrollbar overflow-y-scroll my-3 grid grid-cols-1 gap-4">
{assistants.map((assistant) => (
<DraggableAssistantCard
key={assistant.id.toString()}

View File

@@ -260,29 +260,28 @@
}
}
.default-scrollbar::-webkit-scrollbar {
.include-scrollbar::-webkit-scrollbar {
width: 6px;
}
.default-scrollbar::-webkit-scrollbar-track {
.include-scrollbar::-webkit-scrollbar-track {
background: #f1f1f1;
}
.default-scrollbar::-webkit-scrollbar-thumb {
.include-scrollbar::-webkit-scrollbar-thumb {
background: #888;
border-radius: 4px;
}
.default-scrollbar::-webkit-scrollbar-thumb:hover {
.include-scrollbar::-webkit-scrollbar-thumb:hover {
background: #555;
}
.default-scrollbar {
.include-scrollbar {
scrollbar-width: thin;
scrollbar-color: #888 transparent;
overflow: overlay;
overflow-y: scroll;
overflow-x: hidden;
}
.inputscroll::-webkit-scrollbar-track {

View File

@@ -57,7 +57,7 @@ const DropdownOption: React.FC<DropdownOptionProps> = ({
};
export function UserDropdown({ page }: { page?: pageType }) {
const { user, isCurator } = useUser();
const { user } = useUser();
const [userInfoVisible, setUserInfoVisible] = useState(false);
const userInfoRef = useRef<HTMLDivElement>(null);
const router = useRouter();
@@ -95,9 +95,7 @@ export function UserDropdown({ page }: { page?: pageType }) {
}
// Construct the current URL
const currentUrl = `${pathname}${
searchParams.toString() ? `?${searchParams.toString()}` : ""
}`;
const currentUrl = `${pathname}${searchParams.toString() ? `?${searchParams.toString()}` : ""}`;
// Encode the current URL to use as a redirect parameter
const encodedRedirect = encodeURIComponent(currentUrl);
@@ -108,7 +106,9 @@ export function UserDropdown({ page }: { page?: pageType }) {
};
const showAdminPanel = !user || user.role === UserRole.ADMIN;
const showCuratorPanel = user && isCurator;
const showCuratorPanel =
user &&
(user.role === UserRole.CURATOR || user.role === UserRole.GLOBAL_CURATOR);
const showLogout =
user && !checkUserIsNoAuthUser(user.id) && !LOGOUT_DISABLED;
@@ -244,11 +244,7 @@ export function UserDropdown({ page }: { page?: pageType }) {
setShowNotifications(true);
}}
icon={<BellIcon className="h-5 w-5 my-auto mr-2" />}
label={`Notifications ${
notifications && notifications.length > 0
? `(${notifications.length})`
: ""
}`}
label={`Notifications ${notifications && notifications.length > 0 ? `(${notifications.length})` : ""}`}
/>
{showLogout &&

View File

@@ -47,7 +47,7 @@ export const AssistantsProvider: React.FC<{
const [assistants, setAssistants] = useState<Persona[]>(
initialAssistants || []
);
const { user, isLoadingUser, isAdmin, isCurator } = useUser();
const { user, isLoadingUser, isAdmin } = useUser();
const [editablePersonas, setEditablePersonas] = useState<Persona[]>([]);
const [allAssistants, setAllAssistants] = useState<Persona[]>([]);
@@ -83,7 +83,7 @@ export const AssistantsProvider: React.FC<{
useEffect(() => {
const fetchPersonas = async () => {
if (!isAdmin && !isCurator) {
if (!isAdmin) {
return;
}
@@ -101,8 +101,6 @@ export const AssistantsProvider: React.FC<{
if (allResponse.ok) {
const allPersonas = await allResponse.json();
setAllAssistants(allPersonas);
} else {
console.error("Error fetching personas:", allResponse);
}
} catch (error) {
console.error("Error fetching personas:", error);
@@ -110,7 +108,7 @@ export const AssistantsProvider: React.FC<{
};
fetchPersonas();
}, [isAdmin, isCurator]);
}, [isAdmin]);
const refreshRecentAssistants = async (currentAssistant: number) => {
const response = await fetch("/api/user/recent-assistants", {

View File

@@ -63,10 +63,8 @@ export const LlmList: React.FC<LlmListProps> = ({
return (
<div
className={`${
scrollable
? "max-h-[200px] default-scrollbar overflow-x-hidden"
: "max-h-[300px]"
} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
scrollable ? "max-h-[200px] include-scrollbar" : "max-h-[300px]"
} bg-background-175 flex flex-col gap-y-1 overflow-y-scroll`}
>
{userDefault && (
<button

View File

@@ -18,7 +18,7 @@ export default function ExceptionTraceModal({
title="Full Exception Trace"
onOutsideClick={onOutsideClick}
>
<div className="overflow-y-auto default-scrollbar overflow-x-hidden pr-3 h-full mb-6">
<div className="overflow-y-auto include-scrollbar pr-3 h-full mb-6">
<div className="mb-6">
{!copyClicked ? (
<div

View File

@@ -698,7 +698,7 @@ export const SearchSection = ({
</div>
</div>
<div className="absolute default-scrollbar h-screen overflow-y-auto overflow-x-hidden left-0 w-full top-0">
<div className="absolute include-scrollbar h-screen overflow-y-auto left-0 w-full top-0">
<FunctionalHeader
sidebarToggled={toggledSidebar}
reset={() => setQuery("")}

View File

@@ -67,10 +67,7 @@ export function UserProvider({
isLoadingUser,
refreshUser,
isAdmin: upToDateUser?.role === UserRole.ADMIN,
// Curator status applies for either global or basic curator
isCurator:
upToDateUser?.role === UserRole.CURATOR ||
upToDateUser?.role === UserRole.GLOBAL_CURATOR,
isCurator: upToDateUser?.role === UserRole.CURATOR,
isCloudSuperuser: upToDateUser?.is_cloud_superuser ?? false,
}}
>

View File

@@ -36,10 +36,8 @@ export const SIDEBAR_WIDTH = `w-[350px]`;
export const LOGOUT_DISABLED =
process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";
// Default sidebar open is true if the environment variable is not set
export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true" ??
true;
process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true";
export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";

View File

@@ -174,6 +174,7 @@ export function useLlmOverride(
modelName: "",
}
);
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? destructureValue(currentChatSession.current_alternate_model)

View File

@@ -1,22 +1,14 @@
import { test, expect } from "@chromatic-com/playwright";
test.describe("Admin Performance Query History", () => {
// Ignores the diff for elements targeted by the specified list of selectors
// exclude button since they change based on the date
test.use({ ignoreSelectors: ["button"] });
test(
"Admin - Performance - Query History",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/performance/query-history");
await expect(page.locator("h1.text-3xl")).toHaveText("Query History");
await expect(page.locator("p.text-sm").nth(0)).toHaveText(
"Feedback Type"
);
}
);
});
test(
"Admin - Performance - Query History",
{
tag: "@admin",
},
async ({ page }, testInfo) => {
// Test simple loading
await page.goto("http://localhost:3000/admin/performance/query-history");
await expect(page.locator("h1.text-3xl")).toHaveText("Query History");
await expect(page.locator("p.text-sm").nth(0)).toHaveText("Feedback Type");
}
);

View File

@@ -2,7 +2,6 @@ import { test, expect } from "@chromatic-com/playwright";
test.describe("Admin Performance Usage", () => {
// Ignores the diff for elements targeted by the specified list of selectors
// exclude button and svg since they change based on the date
test.use({ ignoreSelectors: ["button", "svg"] });
test(