mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-01 13:02:42 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c673959714 | ||
|
|
cb36562802 | ||
|
|
efc424bf3e | ||
|
|
e0baaf85e5 | ||
|
|
a0ffd47e2c | ||
|
|
d0396a1337 |
@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -7,7 +8,59 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -449,7 +450,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -19,6 +18,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
|
||||
# Collect and log each metric
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
chat_sessions = list(result.scalars().all())
|
||||
|
||||
return list(chat_sessions)
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
|
||||
@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
|
||||
"""Check if the prompt contains any assistant messages with tool_calls.
|
||||
|
||||
When Anthropic's extended thinking is enabled, the API requires every
|
||||
assistant message to start with a thinking block before any tool_use
|
||||
blocks. Since we don't preserve thinking_blocks (they carry
|
||||
cryptographic signatures that can't be reconstructed), we must skip
|
||||
the thinking param whenever history contains prior tool-calling turns.
|
||||
"""
|
||||
from onyx.llm.models import AssistantMessage
|
||||
|
||||
msgs = prompt if isinstance(prompt, list) else [prompt]
|
||||
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -466,7 +481,20 @@ class LitellmLLM(LLM):
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
if budget_tokens is not None:
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
# across turns, so skip thinking when the history already
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
|
||||
@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
|
||||
@@ -3,12 +3,8 @@
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
|
||||
@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
|
||||
Also patches ``celery_get_broker_client`` so the mock app doesn't need
|
||||
a real broker URL.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for celery_get_broker_client singleton."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery import celery_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> Iterator[None]:
|
||||
"""Reset the module-level singleton between tests."""
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
yield
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
|
||||
|
||||
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.conf.broker_url = broker_url
|
||||
return app
|
||||
|
||||
|
||||
class TestCeleryGetBrokerClient:
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
result = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert result is mock_client
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://localhost:6379/15"
|
||||
assert call_args[1]["decode_responses"] is False
|
||||
assert call_args[1]["socket_keepalive"] is True
|
||||
assert call_args[1]["retry_on_timeout"] is True
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert client1 is client2
|
||||
# from_url called only once
|
||||
assert mock_redis_cls.from_url.call_count == 1
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
|
||||
stale_client = MagicMock()
|
||||
stale_client.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.ping.return_value = True
|
||||
|
||||
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
|
||||
|
||||
app = _make_mock_app()
|
||||
|
||||
# First call creates stale_client
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
assert client1 is stale_client
|
||||
|
||||
# Second call: ping fails, creates fresh_client
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
assert client2 is fresh_client
|
||||
assert mock_redis_cls.from_url.call_count == 2
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_redis_cls.from_url.return_value = MagicMock()
|
||||
|
||||
app = _make_mock_app("redis://custom-host:6380/3")
|
||||
celery_redis.celery_get_broker_client(app)
|
||||
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://custom-host:6380/3"
|
||||
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for get_chat_sessions_by_user filtering behavior.
|
||||
|
||||
Verifies that failed chat sessions (those with only SYSTEM messages) are
|
||||
correctly filtered out while preserving recently created sessions, matching
|
||||
the behavior specified in PR #7233.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def _make_session(
|
||||
user_id: UUID,
|
||||
time_created: datetime | None = None,
|
||||
time_updated: datetime | None = None,
|
||||
description: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock ChatSession with the given attributes."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
session.id = uuid4()
|
||||
session.user_id = user_id
|
||||
session.time_created = time_created or datetime.now(timezone.utc)
|
||||
session.time_updated = time_updated or session.time_created
|
||||
session.description = description
|
||||
session.deleted = False
|
||||
session.onyxbot_flow = False
|
||||
session.project_id = None
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id() -> UUID:
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_time() -> datetime:
|
||||
"""A timestamp well outside the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recent_time() -> datetime:
|
||||
"""A timestamp within the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(minutes=2)
|
||||
|
||||
|
||||
class TestGetChatSessionsByUser:
|
||||
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
|
||||
|
||||
def test_filters_out_failed_sessions(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Sessions with only SYSTEM messages should be excluded."""
|
||||
valid_session = _make_session(user_id, time_created=old_time)
|
||||
failed_session = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
# First execute: returns all sessions
|
||||
# Second execute: returns only the valid session's ID (has non-system msgs)
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
valid_session,
|
||||
failed_session,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == valid_session.id
|
||||
|
||||
def test_keeps_recent_sessions_without_messages(
|
||||
self, user_id: UUID, recent_time: datetime
|
||||
) -> None:
|
||||
"""Recently created sessions should be kept even without messages."""
|
||||
recent_session = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [recent_session]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == recent_session.id
|
||||
# Should only have been called once — no second query needed
|
||||
# because the recent session is within the leeway window
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_include_failed_chats_skips_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""When include_failed_chats=True, no filtering should occur."""
|
||||
session_a = _make_session(user_id, time_created=old_time)
|
||||
session_b = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=True,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Only one DB call — no second query for message validation
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_limit_applied_after_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Limit should be applied after filtering, not before."""
|
||||
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
|
||||
valid_ids = [s.id for s in sessions[:3]]
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = sessions
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = valid_ids
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Should be the first 2 valid sessions (order preserved)
|
||||
assert result[0].id == sessions[0].id
|
||||
assert result[1].id == sessions[1].id
|
||||
|
||||
def test_mixed_recent_and_old_sessions(
|
||||
self, user_id: UUID, old_time: datetime, recent_time: datetime
|
||||
) -> None:
|
||||
"""Mix of recent and old sessions should filter correctly."""
|
||||
old_valid = _make_session(user_id, time_created=old_time)
|
||||
old_failed = _make_session(user_id, time_created=old_time)
|
||||
recent_no_msgs = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
old_valid,
|
||||
old_failed,
|
||||
recent_no_msgs,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
result_ids = {cs.id for cs in result}
|
||||
assert old_valid.id in result_ids
|
||||
assert recent_no_msgs.id in result_ids
|
||||
assert old_failed.id not in result_ids
|
||||
|
||||
def test_empty_result(self, user_id: UUID) -> None:
|
||||
"""No sessions should return empty list without errors."""
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert db_session.execute.call_count == 1
|
||||
@@ -11,6 +11,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from litellm.types.utils import Delta
|
||||
from litellm.types.utils import Function as LiteLLMFunction
|
||||
|
||||
import onyx.llm.models
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
@@ -1479,6 +1480,147 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_true() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(name="get_weather", arguments="{}"),
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is True
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="Hello"),
|
||||
AssistantMessage(content="Hi there!"),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_user_only() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello")]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
|
||||
"""When thinking is enabled but assistant messages with tool_calls lack
|
||||
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
|
||||
BadRequestError about missing thinking blocks."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Paris"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
onyx.llm.models.ToolMessage(
|
||||
content="22°C sunny",
|
||||
tool_call_id="tc_1",
|
||||
),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" not in kwargs, (
|
||||
"thinking param should be dropped when thinking_blocks are missing "
|
||||
"from assistant messages with tool_calls"
|
||||
)
|
||||
|
||||
|
||||
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
|
||||
"""When thinking is enabled and there are no historical assistant messages
|
||||
with tool_calls, the thinking param should be preserved."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" in kwargs, (
|
||||
"thinking param should be preserved when no assistant messages "
|
||||
"with tool_calls exist in history"
|
||||
)
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
|
||||
@@ -1,96 +1,22 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
"""Tests for indexing pipeline setup."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
class TestCollectorCeleryAppSetup:
|
||||
def test_queue_depth_collector_uses_celery_app(self) -> None:
|
||||
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = QueueDepthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
def test_redis_health_collector_uses_celery_app(self) -> None:
|
||||
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = RedisHealthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
@@ -24,8 +23,9 @@ import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button } from "@opal/components";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
@@ -201,10 +201,10 @@ export default function ImageGenerationContent() {
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Section Header */}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text mainContentEmphasis text05>
|
||||
<Text font="main-content-emphasis" color="text-05">
|
||||
Image Generation Model
|
||||
</Text>
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Select a model to generate images in chat.
|
||||
</Text>
|
||||
</div>
|
||||
@@ -223,7 +223,7 @@ export default function ImageGenerationContent() {
|
||||
{/* Provider Groups */}
|
||||
{IMAGE_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.name} className="flex flex-col gap-2">
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.name}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
@@ -277,12 +277,13 @@ export default function ImageGenerationContent() {
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model. Session history will be preserved.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
@@ -329,22 +330,24 @@ export default function ImageGenerationContent() {
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using image generation.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> models will no longer be used
|
||||
to generate images.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** models will no longer be used to generate images.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
|
||||
@@ -15,7 +15,7 @@ import { Callout } from "@/components/ui/callout";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Button } from "@opal/components";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
@@ -151,7 +151,7 @@ function WebSearchDisconnectModal({
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
@@ -159,7 +159,7 @@ function WebSearchDisconnectModal({
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
|
||||
@@ -6,6 +6,8 @@ import { INTERNAL_URL, IS_DEV } from "@/lib/constants";
|
||||
const TARGET_SAMPLE_RATE = 24000;
|
||||
const CHUNK_INTERVAL_MS = 250;
|
||||
const DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS = 1500;
|
||||
// When VAD-based auto-stop is disabled, force-stop after this much silence as a fallback
|
||||
const SILENCE_FALLBACK_TIMEOUT_MS = 10000;
|
||||
|
||||
interface TranscriptMessage {
|
||||
type: "transcript" | "error";
|
||||
@@ -58,6 +60,8 @@ class VoiceRecorderSession {
|
||||
private finalTranscriptDelivered = false;
|
||||
private lastDeliveredFinalText: string | null = null;
|
||||
private lastDeliveredFinalAtMs = 0;
|
||||
// Fallback timer: force-stop after extended silence when VAD auto-stop is disabled
|
||||
private silenceFallbackTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
// Callbacks to update React state
|
||||
private onTranscriptChange: (text: string) => void;
|
||||
@@ -174,6 +178,8 @@ class VoiceRecorderSession {
|
||||
async stop(): Promise<string | null> {
|
||||
if (!this.isActive) return this.transcript || null;
|
||||
|
||||
this.resetSilenceFallbackTimer();
|
||||
|
||||
// Stop audio capture
|
||||
if (this.sendInterval) {
|
||||
clearInterval(this.sendInterval);
|
||||
@@ -219,6 +225,7 @@ class VoiceRecorderSession {
|
||||
}
|
||||
|
||||
cleanup(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
if (this.sendInterval) clearInterval(this.sendInterval);
|
||||
if (this.scriptNode) this.scriptNode.disconnect();
|
||||
if (this.sourceNode) this.sourceNode.disconnect();
|
||||
@@ -274,6 +281,23 @@ class VoiceRecorderSession {
|
||||
});
|
||||
}
|
||||
|
||||
private resetSilenceFallbackTimer(): void {
|
||||
if (this.silenceFallbackTimer) {
|
||||
clearTimeout(this.silenceFallbackTimer);
|
||||
this.silenceFallbackTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private startSilenceFallbackTimer(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
this.silenceFallbackTimer = setTimeout(() => {
|
||||
// 10s of silence with no new speech — force-stop as a safety fallback
|
||||
if (this.isActive && this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
}, SILENCE_FALLBACK_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
private handleMessage = (event: MessageEvent): void => {
|
||||
try {
|
||||
const data: TranscriptMessage = JSON.parse(event.data);
|
||||
@@ -281,47 +305,53 @@ class VoiceRecorderSession {
|
||||
if (data.type === "transcript") {
|
||||
if (data.text) {
|
||||
this.transcript = data.text;
|
||||
this.onTranscriptChange(data.text);
|
||||
// Only push live updates to React while actively recording.
|
||||
// After stop(), the final transcript is returned via stopResolver
|
||||
// instead — this prevents stale text from reappearing in the
|
||||
// input box when the user clears it and starts a new recording.
|
||||
if (this.isActive) {
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
}
|
||||
|
||||
if (data.is_final && data.text) {
|
||||
// VAD detected silence - trigger callback (only once per utterance)
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.autoStopOnSilence &&
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
// Resolve stop promise if waiting — must run even after stop()
|
||||
// so the caller receives the final transcript.
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
|
||||
// Auto-stop recording if enabled
|
||||
// Skip VAD logic if session is no longer active
|
||||
if (!this.isActive) return;
|
||||
|
||||
if (this.autoStopOnSilence) {
|
||||
// Trigger stop callback to update React state
|
||||
// VAD detected silence — auto-stop and trigger callback
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
if (this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
} else {
|
||||
// If not auto-stopping, reset for next utterance
|
||||
this.transcript = "";
|
||||
this.finalTranscriptDelivered = false;
|
||||
this.onTranscriptChange("");
|
||||
this.resetBackendTranscript();
|
||||
}
|
||||
|
||||
// Resolve stop promise if waiting
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
// Auto-stop disabled (push-to-talk): ignore VAD, keep recording.
|
||||
// Start/reset a 10s fallback timer — if no new speech arrives,
|
||||
// force-stop to avoid recording silence indefinitely.
|
||||
this.startSilenceFallbackTimer();
|
||||
}
|
||||
}
|
||||
} else if (data.type === "error") {
|
||||
|
||||
@@ -142,6 +142,7 @@ function PopoverContent({
|
||||
collisionPadding={8}
|
||||
className={cn(
|
||||
"bg-background-neutral-00 p-1 z-popover rounded-12 border shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
"flex flex-col",
|
||||
"max-h-[var(--radix-popover-content-available-height)]",
|
||||
"overflow-hidden",
|
||||
widthClasses[width]
|
||||
@@ -226,7 +227,7 @@ export function PopoverMenu({
|
||||
});
|
||||
|
||||
return (
|
||||
<Section alignItems="stretch">
|
||||
<Section alignItems="stretch" height="auto" className="flex-1 min-h-0">
|
||||
<ShadowDiv
|
||||
scrollContainerRef={scrollContainerRef}
|
||||
className="flex flex-col gap-1 max-h-[20rem] w-full"
|
||||
|
||||
@@ -105,7 +105,7 @@ export default function ShadowDiv({
|
||||
}, [containerRef, checkScroll]);
|
||||
|
||||
return (
|
||||
<div className="relative min-h-0">
|
||||
<div className="relative min-h-0 flex flex-col">
|
||||
<div
|
||||
ref={containerRef}
|
||||
className={cn("overflow-y-auto", className)}
|
||||
|
||||
@@ -984,8 +984,8 @@ function ChatPreferencesSettings() {
|
||||
/>
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Send"
|
||||
description="Automatically send voice input when recording stops."
|
||||
title="Auto-Send on Pause"
|
||||
description="Automatically send voice input when you stop speaking."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_send ?? false}
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
IconProps,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
@@ -26,7 +25,8 @@ import { toast } from "@/hooks/useToast";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgMicrophone, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
@@ -205,7 +205,7 @@ function VoiceDisconnectModal({
|
||||
description="Voice models"
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
@@ -213,19 +213,19 @@ function VoiceDisconnectModal({
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
used for speech-to-text or text-to-speech, and it will no longer
|
||||
be your default. Session history will be preserved.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be used for speech-to-text or text-to-speech, and it will no longer be your default. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
@@ -256,23 +256,24 @@ function VoiceDisconnectModal({
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
used for speech-to-text or text-to-speech, and it will no longer
|
||||
be your default.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be used for speech-to-text or text-to-speech, and it will no longer be your default.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using voice.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
available for voice.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be available for voice.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
@@ -536,7 +537,7 @@ export default function VoiceConfigurationPage() {
|
||||
<Callout type="danger" title="Failed to load voice settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" mainContentBody text03>
|
||||
<Text as="p" font="main-content-body" color="text-03">
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
@@ -626,7 +627,7 @@ export default function VoiceConfigurationPage() {
|
||||
|
||||
{TTS_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.providerType} className="flex flex-col gap-2">
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.providerLabel}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
|
||||
@@ -122,7 +122,10 @@ function MicrophoneButton({
|
||||
startRecording,
|
||||
stopRecording,
|
||||
setMuted,
|
||||
} = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });
|
||||
} = useVoiceRecorder({
|
||||
onFinalTranscript: handleFinalTranscript,
|
||||
autoStopOnSilence: autoSend,
|
||||
});
|
||||
|
||||
// Expose stopRecording to parent
|
||||
useEffect(() => {
|
||||
|
||||
@@ -4,6 +4,14 @@ import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
test.use({ storageState: "admin_auth.json" });
|
||||
|
||||
/** Maps each settings slug to the header title shown on that page. */
|
||||
const SLUG_TO_HEADER: Record<string, string> = {
|
||||
general: "Profile",
|
||||
"chat-preferences": "Chats",
|
||||
"accounts-access": "Accounts",
|
||||
connectors: "Connectors",
|
||||
};
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Settings pages (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
@@ -11,21 +19,33 @@ for (const theme of THEMES) {
|
||||
});
|
||||
|
||||
test("should screenshot each settings tab", async ({ page }) => {
|
||||
await page.goto("/app/settings");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await page.goto("/app/settings/general");
|
||||
await page
|
||||
.getByTestId("settings-left-tab-navigation")
|
||||
.waitFor({ state: "visible" });
|
||||
|
||||
const nav = page.getByTestId("settings-left-tab-navigation");
|
||||
const tabs = nav.locator("a");
|
||||
await expect(tabs.first()).toBeVisible({ timeout: 10_000 });
|
||||
const count = await tabs.count();
|
||||
|
||||
expect(count).toBeGreaterThan(0);
|
||||
for (let i = 0; i < count; i++) {
|
||||
const tab = tabs.nth(i);
|
||||
const href = await tab.getAttribute("href");
|
||||
const slug = href ? href.replace("/app/settings/", "") : `tab-${i}`;
|
||||
|
||||
await tab.click();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const expectedHeader = SLUG_TO_HEADER[slug];
|
||||
if (expectedHeader) {
|
||||
await expect(
|
||||
page
|
||||
.locator(".opal-content-md-header")
|
||||
.filter({ hasText: expectedHeader })
|
||||
).toBeVisible({ timeout: 10_000 });
|
||||
} else {
|
||||
await page.waitForLoadState("networkidle");
|
||||
}
|
||||
|
||||
await expectScreenshot(page, {
|
||||
name: `settings-${theme}-${slug}`,
|
||||
|
||||
Reference in New Issue
Block a user