mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-09 00:42:47 +00:00
Compare commits
22 Commits
refactor/r
...
v3.1.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4579a1365 | ||
|
|
893c094aed | ||
|
|
f8a55712d2 | ||
|
|
591afd4fb1 | ||
|
|
9328070dc0 | ||
|
|
6163521126 | ||
|
|
d42c5616b0 | ||
|
|
aeb4fdd6c1 | ||
|
|
c673959714 | ||
|
|
cb36562802 | ||
|
|
efc424bf3e | ||
|
|
e0baaf85e5 | ||
|
|
a0ffd47e2c | ||
|
|
d0396a1337 | ||
|
|
b9e84c42a8 | ||
|
|
0a1df52c2f | ||
|
|
306b0d452f | ||
|
|
5fdb34ba8e | ||
|
|
2d066631e3 | ||
|
|
5c84f6c61b | ||
|
|
899179d4b6 | ||
|
|
80d6bafc74 |
@@ -0,0 +1,35 @@
|
||||
"""remove voice_provider deleted column
|
||||
|
||||
Revision ID: 1d78c0ca7853
|
||||
Revises: a3f8b2c1d4e5
|
||||
Create Date: 2026-03-26 11:30:53.883127
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1d78c0ca7853"
|
||||
down_revision = "a3f8b2c1d4e5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Hard-delete any soft-deleted rows before dropping the column
|
||||
op.execute("DELETE FROM voice_provider WHERE deleted = true")
|
||||
op.drop_column("voice_provider", "deleted")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"voice_provider",
|
||||
sa.Column(
|
||||
"deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -7,7 +8,59 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -449,7 +450,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -19,6 +18,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
|
||||
# Collect and log each metric
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -12,6 +12,11 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
from jira.resources import Issue
|
||||
@@ -239,29 +240,53 @@ def enhanced_search_ids(
|
||||
)
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO: move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
def _bulk_fetch_request(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts."""
|
||||
bulk_fetch_path = jira_client._get_url("issue/bulkfetch")
|
||||
|
||||
# Prepare the payload according to Jira API v3 specification
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
|
||||
# Only restrict fields if specified, might want to explicitly do this in the future
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
def bulk_fetch_issues(
|
||||
jira_client: JIRA, issue_ids: list[str], fields: str | None = None
|
||||
) -> list[Issue]:
|
||||
# TODO(evan): move away from this jira library if they continue to not support
|
||||
# the endpoints we need. Using private fields is not ideal, but
|
||||
# is likely fine for now since we pin the library version
|
||||
|
||||
try:
|
||||
response = jira_client._session.post(bulk_fetch_path, json=payload).json()
|
||||
raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields)
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
if len(issue_ids) <= 1:
|
||||
logger.exception(
|
||||
f"Jira bulk-fetch response for issue(s) {issue_ids} could not "
|
||||
f"be decoded as JSON (response too large or truncated)."
|
||||
)
|
||||
raise
|
||||
|
||||
mid = len(issue_ids) // 2
|
||||
logger.warning(
|
||||
f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. "
|
||||
f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}."
|
||||
)
|
||||
left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields)
|
||||
right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields)
|
||||
return left + right
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching issues: {e}")
|
||||
raise e
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
for issue in response["issues"]
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
chat_sessions = list(result.scalars().all())
|
||||
|
||||
return list(chat_sessions)
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
|
||||
@@ -8,6 +8,8 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -45,6 +47,23 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
|
||||
return fetch_all_federated_connectors(db_session)
|
||||
|
||||
|
||||
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
|
||||
"""Raise if any credential string value contains mask placeholder characters.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
Both must be rejected.
|
||||
"""
|
||||
for key, val in credentials.items():
|
||||
if isinstance(val, str) and (
|
||||
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
|
||||
)
|
||||
|
||||
|
||||
def validate_federated_connector_credentials(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -66,6 +85,8 @@ def create_federated_connector(
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> FederatedConnector:
|
||||
"""Create a new federated connector with credential and config validation."""
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before creating
|
||||
if not validate_federated_connector_credentials(source, credentials):
|
||||
raise ValueError(
|
||||
@@ -277,6 +298,8 @@ def update_federated_connector(
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before updating
|
||||
if not validate_federated_connector_credentials(
|
||||
federated_connector.source, credentials
|
||||
|
||||
@@ -3135,8 +3135,6 @@ class VoiceProvider(Base):
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -47,7 +46,6 @@ async def fetch_user_for_pat(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -229,7 +229,9 @@ def get_memories_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> Sequence[Memory]:
|
||||
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
return db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
|
||||
).all()
|
||||
|
||||
|
||||
def update_user_pinned_assistants(
|
||||
|
||||
@@ -17,39 +17,30 @@ MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
@@ -58,9 +49,7 @@ def fetch_voice_provider_by_type(
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
@@ -119,10 +108,10 @@ def upsert_voice_provider(
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
|
||||
@@ -49,9 +49,21 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
|
||||
def get_markitdown_converter() -> "MarkItDown":
|
||||
global _MARKITDOWN_CONVERTER
|
||||
from markitdown import MarkItDown
|
||||
|
||||
if _MARKITDOWN_CONVERTER is None:
|
||||
from markitdown import MarkItDown
|
||||
|
||||
# Patch this function to effectively no-op because we were seeing this
|
||||
# module take an inordinate amount of time to convert charts to markdown,
|
||||
# making some powerpoint files with many or complicated charts nearly
|
||||
# unindexable.
|
||||
from markitdown.converters._pptx_converter import PptxConverter
|
||||
|
||||
setattr(
|
||||
PptxConverter,
|
||||
"_convert_chart_to_markdown",
|
||||
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
|
||||
)
|
||||
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
|
||||
return _MARKITDOWN_CONVERTER
|
||||
|
||||
@@ -202,18 +214,26 @@ def read_pdf_file(
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
if pdf_reader.is_encrypted:
|
||||
# Try the explicit password first, then fall back to an empty
|
||||
# string. Owner-password-only PDFs (permission restrictions but
|
||||
# no open password) decrypt successfully with "".
|
||||
# See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
passwords = [p for p in [pdf_pass, ""] if p is not None]
|
||||
decrypt_success = False
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error("Unable to decrypt pdf")
|
||||
for pw in passwords:
|
||||
try:
|
||||
if pdf_reader.decrypt(pw) != 0:
|
||||
decrypt_success = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not decrypt_success:
|
||||
logger.error(
|
||||
"Encrypted PDF could not be decrypted, returning empty text."
|
||||
)
|
||||
return "", metadata, []
|
||||
elif pdf_reader.is_encrypted:
|
||||
logger.warning("No Password for an encrypted PDF, returning empty text.")
|
||||
return "", metadata, []
|
||||
|
||||
# Basic PDF metadata
|
||||
if pdf_reader.metadata is not None:
|
||||
|
||||
@@ -33,8 +33,20 @@ def is_pdf_protected(file: IO[Any]) -> bool:
|
||||
|
||||
with preserve_position(file):
|
||||
reader = PdfReader(file)
|
||||
if not reader.is_encrypted:
|
||||
return False
|
||||
|
||||
return bool(reader.is_encrypted)
|
||||
# PDFs with only an owner password (permission restrictions like
|
||||
# print/copy disabled) use an empty user password — any viewer can open
|
||||
# them without prompting. decrypt("") returns 0 only when a real user
|
||||
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
try:
|
||||
return reader.decrypt("") == 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to evaluate PDF encryption; treating as password protected"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def is_docx_protected(file: IO[Any]) -> bool:
|
||||
|
||||
@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
|
||||
"""Check if the prompt contains any assistant messages with tool_calls.
|
||||
|
||||
When Anthropic's extended thinking is enabled, the API requires every
|
||||
assistant message to start with a thinking block before any tool_use
|
||||
blocks. Since we don't preserve thinking_blocks (they carry
|
||||
cryptographic signatures that can't be reconstructed), we must skip
|
||||
the thinking param whenever history contains prior tool-calling turns.
|
||||
"""
|
||||
from onyx.llm.models import AssistantMessage
|
||||
|
||||
msgs = prompt if isinstance(prompt, list) else [prompt]
|
||||
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -466,7 +481,20 @@ class LitellmLLM(LLM):
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
if budget_tokens is not None:
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
# across turns, so skip thinking when the history already
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -16,6 +17,7 @@ def main() -> None:
|
||||
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
|
||||
return
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
@@ -44,11 +44,12 @@ def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,19 +142,11 @@ def _validate_endpoint(
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
# Any timeout (connect, read, or write) means the configured timeout_seconds
|
||||
# is too low for this endpoint. Report as timeout so the UI directs the user
|
||||
# to increase the timeout setting.
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
"Hook endpoint validation: timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
|
||||
@@ -1524,6 +1524,7 @@ def get_bifrost_available_models(
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -463,3 +463,4 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
display_name: str # Human-readable name from Bifrost API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -147,6 +147,7 @@ class UserInfo(BaseModel):
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
|
||||
memories: list[MemoryItem] | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -191,10 +192,7 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
enable_memory_tool=user.enable_memory_tool,
|
||||
memories=[
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in (user.memories or [])
|
||||
],
|
||||
memories=memories or [],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -57,6 +57,7 @@ from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.user_preferences import deactivate_user
|
||||
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
|
||||
from onyx.db.user_preferences import get_latest_access_token_for_user
|
||||
from onyx.db.user_preferences import get_memories_for_user
|
||||
from onyx.db.user_preferences import update_assistant_preferences
|
||||
from onyx.db.user_preferences import update_user_assistant_visibility
|
||||
from onyx.db.user_preferences import update_user_auto_scroll
|
||||
@@ -823,6 +824,11 @@ def verify_user_logged_in(
|
||||
[],
|
||||
),
|
||||
)
|
||||
memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
@@ -833,6 +839,7 @@ def verify_user_logged_in(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
return user_info
|
||||
@@ -930,7 +937,8 @@ def update_user_personalization_api(
|
||||
else user.enable_memory_tool
|
||||
)
|
||||
existing_memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
|
||||
@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
|
||||
@@ -3,12 +3,8 @@
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
|
||||
@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
|
||||
Also patches ``celery_get_broker_client`` so the mock app doesn't need
|
||||
a real broker URL.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -103,6 +103,11 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
user_emails={"oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="no yuhong allowed",
|
||||
user_emails={"hagen@danswer.ai", "pablo@onyx.app", "chris@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.db.federated import _reject_masked_credentials
|
||||
|
||||
|
||||
class TestRejectMaskedCredentials:
|
||||
"""Verify that masked credential values are never accepted for DB writes.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
_reject_masked_credentials must catch both.
|
||||
"""
|
||||
|
||||
def test_rejects_fully_masked_value(self) -> None:
|
||||
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": masked})
|
||||
|
||||
def test_rejects_long_string_masked_value(self) -> None:
|
||||
"""mask_string returns 'first4...last4' for long strings — the real
|
||||
format used for OAuth credentials like client_id and client_secret."""
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": "1234...7890"})
|
||||
|
||||
def test_rejects_when_any_field_is_masked(self) -> None:
|
||||
"""Even if client_id is real, a masked client_secret must be caught."""
|
||||
with pytest.raises(ValueError, match="client_secret"):
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": MASK_CREDENTIAL_CHAR * 12,
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_real_credentials(self) -> None:
|
||||
# Should not raise
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": "test_client_secret_value",
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_empty_dict(self) -> None:
|
||||
# Should not raise — empty credentials are handled elsewhere
|
||||
_reject_masked_credentials({})
|
||||
|
||||
def test_ignores_non_string_values(self) -> None:
|
||||
# Non-string values (None, bool, int) should pass through
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "real_value",
|
||||
"redirect_uri": None,
|
||||
"some_flag": True,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for celery_get_broker_client singleton."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery import celery_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> Iterator[None]:
|
||||
"""Reset the module-level singleton between tests."""
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
yield
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
|
||||
|
||||
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.conf.broker_url = broker_url
|
||||
return app
|
||||
|
||||
|
||||
class TestCeleryGetBrokerClient:
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
result = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert result is mock_client
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://localhost:6379/15"
|
||||
assert call_args[1]["decode_responses"] is False
|
||||
assert call_args[1]["socket_keepalive"] is True
|
||||
assert call_args[1]["retry_on_timeout"] is True
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert client1 is client2
|
||||
# from_url called only once
|
||||
assert mock_redis_cls.from_url.call_count == 1
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
|
||||
stale_client = MagicMock()
|
||||
stale_client.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.ping.return_value = True
|
||||
|
||||
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
|
||||
|
||||
app = _make_mock_app()
|
||||
|
||||
# First call creates stale_client
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
assert client1 is stale_client
|
||||
|
||||
# Second call: ping fails, creates fresh_client
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
assert client2 is fresh_client
|
||||
assert mock_redis_cls.from_url.call_count == 2
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_redis_cls.from_url.return_value = MagicMock()
|
||||
|
||||
app = _make_mock_app("redis://custom-host:6380/3")
|
||||
celery_redis.celery_get_broker_client(app)
|
||||
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://custom-host:6380/3"
|
||||
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
147
backend/tests/unit/onyx/connectors/jira/test_jira_bulk_fetch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
|
||||
from onyx.connectors.jira.connector import bulk_fetch_issues
|
||||
|
||||
|
||||
def _make_raw_issue(issue_id: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": issue_id,
|
||||
"key": f"TEST-{issue_id}",
|
||||
"fields": {"summary": f"Issue {issue_id}"},
|
||||
}
|
||||
|
||||
|
||||
def _mock_jira_client() -> MagicMock:
|
||||
mock = MagicMock(spec=JIRA)
|
||||
mock._options = {"server": "https://jira.example.com"}
|
||||
mock._session = MagicMock()
|
||||
mock._get_url = MagicMock(
|
||||
return_value="https://jira.example.com/rest/api/3/issue/bulkfetch"
|
||||
)
|
||||
return mock
|
||||
|
||||
|
||||
def test_bulk_fetch_success() -> None:
|
||||
"""Happy path: all issues fetched in one request."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1"), _make_raw_issue("2"), _make_raw_issue("3")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3"])
|
||||
assert len(result) == 3
|
||||
assert all(isinstance(r, Issue) for r in result)
|
||||
client._session.post.assert_called_once()
|
||||
|
||||
|
||||
def test_bulk_fetch_splits_on_json_error() -> None:
|
||||
"""When the full batch fails with JSONDecodeError, sub-batches succeed."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if len(ids) > 2:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 2294125
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
result = bulk_fetch_issues(client, ["1", "2", "3", "4"])
|
||||
assert len(result) == 4
|
||||
returned_ids = {r.raw["id"] for r in result}
|
||||
assert returned_ids == {"1", "2", "3", "4"}
|
||||
assert call_count > 1
|
||||
|
||||
|
||||
def test_bulk_fetch_raises_on_single_unfetchable_issue() -> None:
|
||||
"""A single issue that always fails JSON decode raises after splitting."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if "bad" in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"Expecting ',' delimiter", "doc", 100
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "bad", "2"])
|
||||
|
||||
|
||||
def test_bulk_fetch_non_json_error_propagates() -> None:
|
||||
"""Non-JSONDecodeError exceptions still propagate."""
|
||||
client = _mock_jira_client()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = ValueError("something else broke")
|
||||
client._session.post.return_value = resp
|
||||
|
||||
try:
|
||||
bulk_fetch_issues(client, ["1"])
|
||||
assert False, "Expected ValueError to propagate"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_bulk_fetch_with_fields() -> None:
|
||||
"""Fields parameter is forwarded correctly."""
|
||||
client = _mock_jira_client()
|
||||
raw = [_make_raw_issue("1")]
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": raw}
|
||||
client._session.post.return_value = resp
|
||||
|
||||
bulk_fetch_issues(client, ["1"], fields="summary,description")
|
||||
|
||||
call_payload = client._session.post.call_args[1]["json"]
|
||||
assert call_payload["fields"] == ["summary", "description"]
|
||||
|
||||
|
||||
def test_bulk_fetch_recursive_splitting_raises_on_bad_issue() -> None:
|
||||
"""With a 6-issue batch where one is bad, recursion isolates it and raises."""
|
||||
client = _mock_jira_client()
|
||||
bad_id = "BAD"
|
||||
|
||||
def _post_side_effect(url: str, json: dict[str, Any]) -> MagicMock: # noqa: ARG001
|
||||
ids = json["issueIdsOrKeys"]
|
||||
if bad_id in ids:
|
||||
resp = MagicMock()
|
||||
resp.json.side_effect = requests.exceptions.JSONDecodeError(
|
||||
"truncated", "doc", 999
|
||||
)
|
||||
return resp
|
||||
|
||||
resp = MagicMock()
|
||||
resp.json.return_value = {"issues": [_make_raw_issue(i) for i in ids]}
|
||||
return resp
|
||||
|
||||
client._session.post.side_effect = _post_side_effect
|
||||
|
||||
with pytest.raises(requests.exceptions.JSONDecodeError):
|
||||
bulk_fetch_issues(client, ["1", "2", bad_id, "3", "4", "5"])
|
||||
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for get_chat_sessions_by_user filtering behavior.
|
||||
|
||||
Verifies that failed chat sessions (those with only SYSTEM messages) are
|
||||
correctly filtered out while preserving recently created sessions, matching
|
||||
the behavior specified in PR #7233.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def _make_session(
|
||||
user_id: UUID,
|
||||
time_created: datetime | None = None,
|
||||
time_updated: datetime | None = None,
|
||||
description: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock ChatSession with the given attributes."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
session.id = uuid4()
|
||||
session.user_id = user_id
|
||||
session.time_created = time_created or datetime.now(timezone.utc)
|
||||
session.time_updated = time_updated or session.time_created
|
||||
session.description = description
|
||||
session.deleted = False
|
||||
session.onyxbot_flow = False
|
||||
session.project_id = None
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id() -> UUID:
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_time() -> datetime:
|
||||
"""A timestamp well outside the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recent_time() -> datetime:
|
||||
"""A timestamp within the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(minutes=2)
|
||||
|
||||
|
||||
class TestGetChatSessionsByUser:
|
||||
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
|
||||
|
||||
def test_filters_out_failed_sessions(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Sessions with only SYSTEM messages should be excluded."""
|
||||
valid_session = _make_session(user_id, time_created=old_time)
|
||||
failed_session = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
# First execute: returns all sessions
|
||||
# Second execute: returns only the valid session's ID (has non-system msgs)
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
valid_session,
|
||||
failed_session,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == valid_session.id
|
||||
|
||||
def test_keeps_recent_sessions_without_messages(
|
||||
self, user_id: UUID, recent_time: datetime
|
||||
) -> None:
|
||||
"""Recently created sessions should be kept even without messages."""
|
||||
recent_session = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [recent_session]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == recent_session.id
|
||||
# Should only have been called once — no second query needed
|
||||
# because the recent session is within the leeway window
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_include_failed_chats_skips_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""When include_failed_chats=True, no filtering should occur."""
|
||||
session_a = _make_session(user_id, time_created=old_time)
|
||||
session_b = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=True,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Only one DB call — no second query for message validation
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_limit_applied_after_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Limit should be applied after filtering, not before."""
|
||||
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
|
||||
valid_ids = [s.id for s in sessions[:3]]
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = sessions
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = valid_ids
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Should be the first 2 valid sessions (order preserved)
|
||||
assert result[0].id == sessions[0].id
|
||||
assert result[1].id == sessions[1].id
|
||||
|
||||
def test_mixed_recent_and_old_sessions(
|
||||
self, user_id: UUID, old_time: datetime, recent_time: datetime
|
||||
) -> None:
|
||||
"""Mix of recent and old sessions should filter correctly."""
|
||||
old_valid = _make_session(user_id, time_created=old_time)
|
||||
old_failed = _make_session(user_id, time_created=old_time)
|
||||
recent_no_msgs = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
old_valid,
|
||||
old_failed,
|
||||
recent_no_msgs,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
result_ids = {cs.id for cs in result}
|
||||
assert old_valid.id in result_ids
|
||||
assert recent_no_msgs.id in result_ids
|
||||
assert old_failed.id not in result_ids
|
||||
|
||||
def test_empty_result(self, user_id: UUID) -> None:
|
||||
"""No sessions should return empty list without errors."""
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert db_session.execute.call_count == 1
|
||||
@@ -272,13 +272,13 @@ class TestUpsertVoiceProvider:
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
def test_hard_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.delete.assert_called_once_with(provider)
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer <1083d595b1>
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 42
|
||||
>>
|
||||
stream
|
||||
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD><0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
|
||||
endstream
|
||||
endobj
|
||||
6 0 obj
|
||||
<<
|
||||
/V 2
|
||||
/R 3
|
||||
/Length 128
|
||||
/P 4294967292
|
||||
/Filter /Standard
|
||||
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
|
||||
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
|
||||
>>
|
||||
endobj
|
||||
xref
|
||||
0 7
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000059 00000 n
|
||||
0000000118 00000 n
|
||||
0000000167 00000 n
|
||||
0000000348 00000 n
|
||||
0000000440 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 7
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
|
||||
/Encrypt 6 0 R
|
||||
>>
|
||||
startxref
|
||||
655
|
||||
%%EOF
|
||||
@@ -54,6 +54,12 @@ class TestReadPdfFile:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
|
||||
assert text == ""
|
||||
|
||||
def test_owner_password_only_pdf_extracts_text(self) -> None:
|
||||
"""A PDF encrypted with only an owner password (no user password)
|
||||
should still yield its text content. Regression for #9754."""
|
||||
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
|
||||
assert "Hello World" in text
|
||||
|
||||
def test_empty_pdf(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("empty.pdf"))
|
||||
assert text.strip() == ""
|
||||
@@ -117,6 +123,12 @@ class TestIsPdfProtected:
|
||||
def test_protected_pdf(self) -> None:
|
||||
assert is_pdf_protected(_load("encrypted.pdf")) is True
|
||||
|
||||
def test_owner_password_only_is_not_protected(self) -> None:
|
||||
"""A PDF with only an owner password (permission restrictions) but no
|
||||
user password should NOT be considered protected — any viewer can open
|
||||
it without prompting for a password."""
|
||||
assert is_pdf_protected(_load("owner_protected.pdf")) is False
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("simple.pdf")
|
||||
pdf.seek(42)
|
||||
|
||||
79
backend/tests/unit/onyx/file_processing/test_pptx_to_text.py
Normal file
79
backend/tests/unit/onyx/file_processing/test_pptx_to_text.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import io
|
||||
|
||||
from pptx import Presentation # type: ignore[import-untyped]
|
||||
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
|
||||
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
|
||||
from pptx.util import Inches # type: ignore[import-untyped]
|
||||
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
|
||||
|
||||
def _make_pptx_with_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with one text slide and one chart slide."""
|
||||
prs = Presentation()
|
||||
|
||||
# Slide 1: text only
|
||||
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide1.shapes.title.text = "Introduction"
|
||||
slide1.placeholders[1].text = "This is the first slide."
|
||||
|
||||
# Slide 2: chart
|
||||
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
|
||||
chart_data = CategoryChartData()
|
||||
chart_data.categories = ["Q1", "Q2", "Q3"]
|
||||
chart_data.add_series("Revenue", (100, 200, 300))
|
||||
slide2.shapes.add_chart(
|
||||
XL_CHART_TYPE.COLUMN_CLUSTERED,
|
||||
Inches(1),
|
||||
Inches(1),
|
||||
Inches(6),
|
||||
Inches(4),
|
||||
chart_data,
|
||||
)
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
def _make_pptx_without_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with a single text-only slide."""
|
||||
prs = Presentation()
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide.shapes.title.text = "Hello World"
|
||||
slide.placeholders[1].text = "Some content here."
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestPptxToText:
|
||||
def test_chart_is_omitted(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_with_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Introduction" in result
|
||||
assert "first slide" in result
|
||||
assert "[chart omitted]" in result
|
||||
# The actual chart data should NOT appear in the output.
|
||||
assert "Revenue" not in result
|
||||
assert "Q1" not in result
|
||||
|
||||
def test_text_only_pptx(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_without_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Hello World" in result
|
||||
assert "Some content" in result
|
||||
assert "[chart omitted]" not in result
|
||||
@@ -11,6 +11,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from litellm.types.utils import Delta
|
||||
from litellm.types.utils import Function as LiteLLMFunction
|
||||
|
||||
import onyx.llm.models
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
@@ -1479,6 +1480,147 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_true() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(name="get_weather", arguments="{}"),
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is True
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="Hello"),
|
||||
AssistantMessage(content="Hi there!"),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_user_only() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello")]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
|
||||
"""When thinking is enabled but assistant messages with tool_calls lack
|
||||
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
|
||||
BadRequestError about missing thinking blocks."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Paris"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
onyx.llm.models.ToolMessage(
|
||||
content="22°C sunny",
|
||||
tool_call_id="tc_1",
|
||||
),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" not in kwargs, (
|
||||
"thinking param should be dropped when thinking_blocks are missing "
|
||||
"from assistant messages with tool_calls"
|
||||
)
|
||||
|
||||
|
||||
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
|
||||
"""When thinking is enabled and there are no historical assistant messages
|
||||
with tool_calls, the thinking param should be preserved."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" in kwargs, (
|
||||
"thinking param should be preserved when no assistant messages "
|
||||
"with tool_calls exist in history"
|
||||
)
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
Covers:
|
||||
- _check_ssrf_safety: scheme enforcement and private-IP blocklist
|
||||
- _validate_endpoint: httpx exception → HookValidateStatus mapping
|
||||
ConnectTimeout → cannot_connect (TCP handshake never completed)
|
||||
ConnectTimeout → timeout (any timeout directs user to increase timeout_seconds)
|
||||
ConnectError → cannot_connect (DNS / TLS failure)
|
||||
ReadTimeout et al. → timeout (TCP connected, server slow)
|
||||
Any other exc → cannot_connect
|
||||
@@ -61,7 +61,7 @@ class TestCheckSsrfSafety:
|
||||
def test_non_https_scheme_rejected(self, url: str) -> None:
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
self._call(url)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert "https" in (exc_info.value.detail or "").lower()
|
||||
|
||||
# --- private IP blocklist ---
|
||||
@@ -87,7 +87,7 @@ class TestCheckSsrfSafety:
|
||||
):
|
||||
mock_dns.return_value = [(None, None, None, None, (ip, 0))]
|
||||
self._call("https://internal.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
assert ip in (exc_info.value.detail or "")
|
||||
|
||||
def test_public_ip_is_allowed(self) -> None:
|
||||
@@ -106,7 +106,7 @@ class TestCheckSsrfSafety:
|
||||
pytest.raises(OnyxError) as exc_info,
|
||||
):
|
||||
self._call("https://no-such-host.example.com/hook")
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INVALID_INPUT
|
||||
assert exc_info.value.error_code == OnyxErrorCode.BAD_GATEWAY
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,13 +158,11 @@ class TestValidateEndpoint:
|
||||
assert self._call().status == HookValidateStatus.passed
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
def test_connect_timeout_returns_cannot_connect(
|
||||
self, mock_client_cls: MagicMock
|
||||
) -> None:
|
||||
def test_connect_timeout_returns_timeout(self, mock_client_cls: MagicMock) -> None:
|
||||
mock_client_cls.return_value.__enter__.return_value.post.side_effect = (
|
||||
httpx.ConnectTimeout("timed out")
|
||||
)
|
||||
assert self._call().status == HookValidateStatus.cannot_connect
|
||||
assert self._call().status == HookValidateStatus.timeout
|
||||
|
||||
@patch("onyx.server.features.hooks.api.httpx.Client")
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
|
||||
@@ -1,96 +1,22 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
"""Tests for indexing pipeline setup."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
class TestCollectorCeleryAppSetup:
|
||||
def test_queue_depth_collector_uses_celery_app(self) -> None:
|
||||
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = QueueDepthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
def test_redis_health_collector_uses_celery_app(self) -> None:
|
||||
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = RedisHealthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
@@ -70,6 +70,10 @@ backend = [
|
||||
"lazy_imports==1.0.1",
|
||||
"lxml==5.3.0",
|
||||
"Mako==1.2.4",
|
||||
# NOTE: Do not update without understanding the patching behavior in
|
||||
# get_markitdown_converter in
|
||||
# backend/onyx/file_processing/extract_file_text.py and what impacts
|
||||
# updating might have on this behavior.
|
||||
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
|
||||
"mcp[cli]==1.26.0",
|
||||
"msal==1.34.0",
|
||||
|
||||
22
web/lib/opal/src/icons/bifrost.tsx
Normal file
22
web/lib/opal/src/icons/bifrost.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import { cn } from "@opal/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgBifrost = ({ size, className, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 37 46"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={cn(className, "text-[#33C19E] dark:text-white")}
|
||||
{...props}
|
||||
>
|
||||
<title>Bifrost</title>
|
||||
<path
|
||||
d="M27.6219 46H0V36.8H27.6219V46ZM36.8268 36.8H27.6219V27.6H36.8268V36.8ZM18.4146 27.6H9.2073V18.4H18.4146V27.6ZM36.8268 18.4H27.6219V9.2H36.8268V18.4ZM27.6219 9.2H0V0H27.6219V9.2Z"
|
||||
fill="currentColor"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgBifrost;
|
||||
@@ -24,6 +24,7 @@ export { default as SvgAzure } from "@opal/icons/azure";
|
||||
export { default as SvgBarChart } from "@opal/icons/bar-chart";
|
||||
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
export { default as SvgBifrost } from "@opal/icons/bifrost";
|
||||
export { default as SvgBlocks } from "@opal/icons/blocks";
|
||||
export { default as SvgBookOpen } from "@opal/icons/book-open";
|
||||
export { default as SvgBookmark } from "@opal/icons/bookmark";
|
||||
|
||||
@@ -30,8 +30,11 @@ import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgInfo, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
import { BillingStatus, hasActiveSubscription } from "@/lib/billing/interfaces";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
|
||||
const route = ADMIN_ROUTES.API_KEYS;
|
||||
@@ -44,6 +47,11 @@ function Main() {
|
||||
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const canCreateKeys = useCloudSubscription();
|
||||
const { data: billingData } = useBillingInformation();
|
||||
const isTrialing =
|
||||
billingData !== undefined &&
|
||||
hasActiveSubscription(billingData) &&
|
||||
billingData.status === BillingStatus.TRIALING;
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
|
||||
@@ -75,6 +83,16 @@ function Main() {
|
||||
|
||||
const introSection = (
|
||||
<div className="flex flex-col items-start gap-4">
|
||||
{isTrialing && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
close={false}
|
||||
className="w-full"
|
||||
text="Upgrade to a paid plan to create API keys."
|
||||
description="Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
)}
|
||||
<Text as="p">
|
||||
API Keys allow you to access Onyx APIs programmatically.
|
||||
{canCreateKeys
|
||||
@@ -85,23 +103,9 @@ function Main() {
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Text as="p" text04>
|
||||
Upgrade to a paid plan to create API keys.
|
||||
</Text>
|
||||
<Button
|
||||
variant="none"
|
||||
prominence="tertiary"
|
||||
size="2xs"
|
||||
icon={SvgInfo}
|
||||
tooltip="API keys enable programmatic access to Onyx for service accounts and integrations. Trial accounts do not include API key access — purchase a paid subscription to unlock this feature."
|
||||
/>
|
||||
</div>
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
</div>
|
||||
)}
|
||||
) : isTrialing ? (
|
||||
<Button href="/admin/billing">Upgrade to Paid Plan</Button>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
387
web/src/app/admin/billing/page.test.tsx
Normal file
387
web/src/app/admin/billing/page.test.tsx
Normal file
@@ -0,0 +1,387 @@
|
||||
/**
|
||||
* Tests for BillingPage handleBillingReturn retry logic.
|
||||
*
|
||||
* The retry logic retries claimLicense up to 3 times with 2s backoff
|
||||
* when returning from a Stripe checkout session. This prevents the user
|
||||
* from getting stranded when the Stripe webhook fires concurrently with
|
||||
* the browser redirect and the license isn't ready yet.
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, waitFor } from "@tests/setup/test-utils";
|
||||
import { act } from "@testing-library/react";
|
||||
|
||||
// ---- Stable mock objects (must be named with mock* prefix for jest hoisting) ----
|
||||
// useRouter and useSearchParams must return the SAME reference each call, otherwise
|
||||
// React's useEffect sees them as changed and re-runs the effect on every render.
|
||||
const mockRouter = {
|
||||
replace: jest.fn() as jest.Mock,
|
||||
refresh: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockSearchParams = {
|
||||
get: jest.fn() as jest.Mock,
|
||||
};
|
||||
const mockClaimLicense = jest.fn() as jest.Mock;
|
||||
const mockRefreshBilling = jest.fn() as jest.Mock;
|
||||
const mockRefreshLicense = jest.fn() as jest.Mock;
|
||||
|
||||
// ---- Mocks ----
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => mockRouter,
|
||||
useSearchParams: () => mockSearchParams,
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/settings-layouts", () => ({
|
||||
Root: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-root">{children}</div>
|
||||
),
|
||||
Header: () => <div data-testid="settings-header" />,
|
||||
Body: ({ children }: { children: React.ReactNode }) => (
|
||||
<div data-testid="settings-body">{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/layouts/general-layouts", () => ({
|
||||
Section: ({ children }: { children: React.ReactNode }) => (
|
||||
<div>{children}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@opal/icons", () => ({
|
||||
SvgArrowUpCircle: () => <svg />,
|
||||
SvgWallet: () => <svg />,
|
||||
}));
|
||||
|
||||
jest.mock("./PlansView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="plans-view" />,
|
||||
}));
|
||||
jest.mock("./CheckoutView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="checkout-view" />,
|
||||
}));
|
||||
jest.mock("./BillingDetailsView", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="billing-details-view" />,
|
||||
}));
|
||||
jest.mock("./LicenseActivationCard", () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="license-activation-card" />,
|
||||
}));
|
||||
|
||||
jest.mock("@/refresh-components/messages/Message", () => ({
|
||||
__esModule: true,
|
||||
default: ({
|
||||
text,
|
||||
description,
|
||||
onClose,
|
||||
}: {
|
||||
text: string;
|
||||
description?: string;
|
||||
onClose?: () => void;
|
||||
}) => (
|
||||
<div data-testid="activating-banner">
|
||||
<span data-testid="activating-banner-text">{text}</span>
|
||||
{description && (
|
||||
<span data-testid="activating-banner-description">{description}</span>
|
||||
)}
|
||||
{onClose && (
|
||||
<button data-testid="activating-banner-close" onClick={onClose}>
|
||||
Close
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/billing", () => ({
|
||||
useBillingInformation: jest.fn(),
|
||||
useLicense: jest.fn(),
|
||||
hasActiveSubscription: jest.fn().mockReturnValue(false),
|
||||
claimLicense: (...args: unknown[]) => mockClaimLicense(...args),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/constants", () => ({
|
||||
NEXT_PUBLIC_CLOUD_ENABLED: false,
|
||||
}));
|
||||
|
||||
// ---- Import after mocks ----
|
||||
import BillingPage from "./page";
|
||||
import { useBillingInformation, useLicense } from "@/lib/billing";
|
||||
|
||||
// ---- Test helpers ----
|
||||
|
||||
function setupHooks() {
|
||||
(useBillingInformation as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
error: null,
|
||||
refresh: mockRefreshBilling,
|
||||
});
|
||||
(useLicense as jest.Mock).mockReturnValue({
|
||||
data: null,
|
||||
isLoading: false,
|
||||
refresh: mockRefreshLicense,
|
||||
});
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
describe("BillingPage — handleBillingReturn retry logic", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
jest.useFakeTimers();
|
||||
setupHooks();
|
||||
// Default: no billing-return params
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Clear any activating state from prior tests
|
||||
sessionStorage.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
test("calls claimLicense once and refreshes on first-attempt success", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_test_123" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith("cs_test_123");
|
||||
});
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// URL cleaned up after checkout return
|
||||
expect(mockRouter.replace).toHaveBeenCalledWith("/admin/billing", {
|
||||
scroll: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("retries after first failure and succeeds on second attempt", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_retry_test" : null
|
||||
);
|
||||
mockClaimLicense
|
||||
.mockRejectedValueOnce(new Error("License not ready yet"))
|
||||
.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
// On eventual success, router and billing should be refreshed
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("retries all 3 times then navigates to details even on total failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
// All 3 attempts fail
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
// User stays on plans view with the activating banner
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("plans-view")).toBeInTheDocument();
|
||||
});
|
||||
// refreshBilling still fires so billing state is up to date
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
// Failure is logged
|
||||
expect(consoleSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining("Failed to sync license after billing return"),
|
||||
expect.any(Error)
|
||||
);
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("calls claimLicense without session_id on portal_return", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "portal_return" ? "true" : null
|
||||
);
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockClaimLicense).toHaveBeenCalledTimes(1);
|
||||
// No session_id for portal returns — called with undefined
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("does not call claimLicense when no billing-return params present", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("shows activating banner and sets sessionStorage on 3x retry failure", async () => {
|
||||
mockSearchParams.get.mockImplementation((key: string) =>
|
||||
key === "session_id" ? "cs_all_fail" : null
|
||||
);
|
||||
mockClaimLicense.mockRejectedValue(new Error("Webhook not processed yet"));
|
||||
|
||||
const consoleSpy = jest
|
||||
.spyOn(console, "error")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByTestId("activating-banner-text")).toHaveTextContent(
|
||||
"Your license is still activating"
|
||||
);
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).not.toBeNull();
|
||||
|
||||
consoleSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("banner not rendered when no activating state", async () => {
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner shown on mount when sessionStorage key is set and not expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush React effects — banner is visible from lazy state init, no timer advancement needed
|
||||
await act(async () => {});
|
||||
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("banner not shown on mount when sessionStorage key is expired", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() - 1000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
await act(async () => {
|
||||
await jest.runAllTimersAsync();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
|
||||
test("poll calls claimLicense after 15s and clears banner on success", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
// Poll attempt succeeds
|
||||
mockClaimLicense.mockResolvedValueOnce({ success: true });
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
// Advance past one poll interval (15s)
|
||||
await act(async () => {
|
||||
await jest.advanceTimersByTimeAsync(15_000);
|
||||
});
|
||||
|
||||
expect(mockClaimLicense).toHaveBeenCalledWith(undefined);
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
expect(mockRefreshBilling).toHaveBeenCalled();
|
||||
expect(mockRefreshLicense).toHaveBeenCalled();
|
||||
expect(mockRouter.refresh).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("close button removes banner and clears sessionStorage", async () => {
|
||||
sessionStorage.setItem(
|
||||
"billing_license_activating_until",
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
mockSearchParams.get.mockReturnValue(null);
|
||||
|
||||
render(<BillingPage />);
|
||||
|
||||
// Flush effects — banner visible from lazy state init
|
||||
await act(async () => {});
|
||||
expect(screen.getByTestId("activating-banner")).toBeInTheDocument();
|
||||
|
||||
const closeButton = screen.getByTestId("activating-banner-close");
|
||||
await act(async () => {
|
||||
closeButton.click();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("activating-banner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
sessionStorage.getItem("billing_license_activating_until")
|
||||
).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
} from "@/lib/billing";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
|
||||
import PlansView from "./PlansView";
|
||||
import CheckoutView from "./CheckoutView";
|
||||
@@ -24,6 +25,9 @@ import BillingDetailsView from "./BillingDetailsView";
|
||||
import LicenseActivationCard from "./LicenseActivationCard";
|
||||
import "./billing.css";
|
||||
|
||||
// sessionStorage key: value is a unix-ms expiry timestamp
|
||||
const BILLING_ACTIVATING_KEY = "billing_license_activating_until";
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Types
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -105,6 +109,7 @@ export default function BillingPage() {
|
||||
const [transitionType, setTransitionType] = useState<
|
||||
"expand" | "collapse" | "fade"
|
||||
>("fade");
|
||||
const [isActivating, setIsActivating] = useState<boolean>(false);
|
||||
|
||||
const {
|
||||
data: billingData,
|
||||
@@ -155,6 +160,17 @@ export default function BillingPage() {
|
||||
view,
|
||||
]);
|
||||
|
||||
// Read activating state from sessionStorage after mount (avoids SSR hydration mismatch)
|
||||
useEffect(() => {
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw) return;
|
||||
if (Number(raw) > Date.now()) {
|
||||
setIsActivating(true);
|
||||
} else {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Show license activation card when there's a Stripe error
|
||||
useEffect(() => {
|
||||
if (hasStripeError && !showLicenseActivationInput) {
|
||||
@@ -172,24 +188,96 @@ export default function BillingPage() {
|
||||
|
||||
router.replace("/admin/billing", { scroll: false });
|
||||
|
||||
let cancelled = false;
|
||||
|
||||
const handleBillingReturn = async () => {
|
||||
if (!NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
} catch (error) {
|
||||
console.error("Failed to sync license after billing return:", error);
|
||||
// Retry up to 3 times with 2s backoff. The license may not be available
|
||||
// immediately if the Stripe webhook hasn't finished processing yet
|
||||
// (redirect and webhook fire nearly simultaneously).
|
||||
let lastError: Error | null = null;
|
||||
for (let attempt = 0; attempt < 3; attempt++) {
|
||||
if (cancelled) return;
|
||||
try {
|
||||
// After checkout, exchange session_id for license; after portal, re-sync license
|
||||
await claimLicense(sessionId ?? undefined);
|
||||
if (cancelled) return;
|
||||
refreshLicense();
|
||||
// Refresh the page to update settings (including ee_features_enabled)
|
||||
router.refresh();
|
||||
// Navigate to billing details now that the license is active
|
||||
changeView("details");
|
||||
lastError = null;
|
||||
break;
|
||||
} catch (err) {
|
||||
lastError = err instanceof Error ? err : new Error("Unknown error");
|
||||
if (attempt < 2) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 2000));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (cancelled) return;
|
||||
if (lastError) {
|
||||
console.error(
|
||||
"Failed to sync license after billing return:",
|
||||
lastError
|
||||
);
|
||||
// Show an activating banner on the plans view and keep retrying in the background.
|
||||
sessionStorage.setItem(
|
||||
BILLING_ACTIVATING_KEY,
|
||||
String(Date.now() + 120_000)
|
||||
);
|
||||
setIsActivating(true);
|
||||
changeView("plans");
|
||||
}
|
||||
}
|
||||
refreshBilling();
|
||||
if (!cancelled) refreshBilling();
|
||||
};
|
||||
handleBillingReturn();
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]);
|
||||
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
// changeView intentionally omitted: it only calls stable state setters and the
|
||||
// effect runs at most once (when session_id/portal_return params are present).
|
||||
}, [searchParams, router, refreshBilling, refreshLicense]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
// Poll every 15s while activating, up to 2 minutes, to detect when the license arrives.
|
||||
useEffect(() => {
|
||||
if (!isActivating) return;
|
||||
|
||||
let requestInFlight = false;
|
||||
|
||||
const intervalId = setInterval(async () => {
|
||||
if (requestInFlight) return;
|
||||
const raw = sessionStorage.getItem(BILLING_ACTIVATING_KEY);
|
||||
if (!raw || Number(raw) <= Date.now()) {
|
||||
// Expired — stop immediately without waiting for React cleanup
|
||||
clearInterval(intervalId);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
return;
|
||||
}
|
||||
requestInFlight = true;
|
||||
try {
|
||||
await claimLicense(undefined);
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
refreshLicense();
|
||||
refreshBilling();
|
||||
router.refresh();
|
||||
changeView("details");
|
||||
} catch (err) {
|
||||
// License not ready yet — keep polling. Log so unexpected failures
|
||||
// (network errors, 500s) are distinguishable from expected 404s.
|
||||
console.debug("License activation poll: will retry", err);
|
||||
} finally {
|
||||
requestInFlight = false;
|
||||
}
|
||||
}, 15_000);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [isActivating]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const handleRefresh = async () => {
|
||||
await Promise.all([
|
||||
@@ -386,6 +474,22 @@ export default function BillingPage() {
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<div className="flex flex-col items-center gap-6">
|
||||
{isActivating && (
|
||||
<Message
|
||||
static
|
||||
warning
|
||||
large
|
||||
text="Your license is still activating"
|
||||
description="Your license is being processed. You'll be taken to billing details automatically once confirmed."
|
||||
icon
|
||||
close
|
||||
onClose={() => {
|
||||
sessionStorage.removeItem(BILLING_ACTIVATING_KEY);
|
||||
setIsActivating(false);
|
||||
}}
|
||||
className="w-full"
|
||||
/>
|
||||
)}
|
||||
{renderContent()}
|
||||
{renderFooter()}
|
||||
</div>
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo } from "react";
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
@@ -17,9 +17,17 @@ import {
|
||||
ImageGenerationConfigView,
|
||||
setDefaultImageGenerationConfig,
|
||||
unsetDefaultImageGenerationConfig,
|
||||
deleteImageGenerationConfig,
|
||||
} from "@/lib/configuration/imageConfigurationService";
|
||||
import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
export default function ImageGenerationContent() {
|
||||
const {
|
||||
@@ -47,6 +55,11 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
const [editConfig, setEditConfig] =
|
||||
useState<ImageGenerationConfigView | null>(null);
|
||||
const [disconnectProvider, setDisconnectProvider] =
|
||||
useState<ImageProvider | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
|
||||
const connectedProviderIds = useMemo(() => {
|
||||
return new Set(configs.map((c) => c.image_provider_id));
|
||||
@@ -115,6 +128,29 @@ export default function ImageGenerationContent() {
|
||||
modal.toggle(true);
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
if (!disconnectProvider) return;
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
await setDefaultImageGenerationConfig(replacementProviderId);
|
||||
}
|
||||
|
||||
await deleteImageGenerationConfig(disconnectProvider.image_provider_id);
|
||||
toast.success(`${disconnectProvider.title} disconnected`);
|
||||
refetchConfigs();
|
||||
refetchProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect image generation provider:", error);
|
||||
toast.error(
|
||||
error instanceof Error ? error.message : "Failed to disconnect"
|
||||
);
|
||||
} finally {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
toast.success("Provider configured successfully");
|
||||
setEditConfig(null);
|
||||
@@ -130,15 +166,45 @@ export default function ImageGenerationContent() {
|
||||
);
|
||||
}
|
||||
|
||||
// Compute replacement options when disconnecting an active provider
|
||||
const isDisconnectingDefault =
|
||||
disconnectProvider &&
|
||||
defaultConfig?.image_provider_id === disconnectProvider.image_provider_id;
|
||||
|
||||
// Group connected replacement models by provider (excluding the model being disconnected)
|
||||
const replacementGroups = useMemo(() => {
|
||||
if (!disconnectProvider) return [];
|
||||
return IMAGE_PROVIDER_GROUPS.map((group) => ({
|
||||
...group,
|
||||
providers: group.providers.filter(
|
||||
(p) =>
|
||||
p.image_provider_id !== disconnectProvider.image_provider_id &&
|
||||
connectedProviderIds.has(p.image_provider_id)
|
||||
),
|
||||
})).filter((g) => g.providers.length > 0);
|
||||
}, [disconnectProvider, connectedProviderIds]);
|
||||
|
||||
const needsReplacement = !!isDisconnectingDefault;
|
||||
const hasReplacements = replacementGroups.length > 0;
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && !replacementProviderId && hasReplacements) {
|
||||
const firstGroup = replacementGroups[0];
|
||||
const firstModel = firstGroup?.providers[0];
|
||||
if (firstModel) setReplacementProviderId(firstModel.image_provider_id);
|
||||
}
|
||||
}, [disconnectProvider]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Section Header */}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text mainContentEmphasis text05>
|
||||
<Text font="main-content-emphasis" color="text-05">
|
||||
Image Generation Model
|
||||
</Text>
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Select a model to generate images in chat.
|
||||
</Text>
|
||||
</div>
|
||||
@@ -157,7 +223,7 @@ export default function ImageGenerationContent() {
|
||||
{/* Provider Groups */}
|
||||
{IMAGE_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.name} className="flex flex-col gap-2">
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.name}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
@@ -175,6 +241,11 @@ export default function ImageGenerationContent() {
|
||||
onSelect={() => handleSelect(provider)}
|
||||
onDeselect={() => handleDeselect(provider)}
|
||||
onEdit={() => handleEdit(provider)}
|
||||
onDisconnect={
|
||||
getStatus(provider) !== "disconnected"
|
||||
? () => setDisconnectProvider(provider)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
@@ -182,6 +253,108 @@ export default function ImageGenerationContent() {
|
||||
))}
|
||||
</div>
|
||||
|
||||
{disconnectProvider && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectProvider.title}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={() => {
|
||||
setDisconnectProvider(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={() => void handleDisconnect()}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => setReplacementProviderId(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement model" />
|
||||
<InputSelect.Content>
|
||||
{replacementGroups.map((group) => (
|
||||
<InputSelect.Group key={group.name}>
|
||||
<InputSelect.Label>{group.name}</InputSelect.Label>
|
||||
{group.providers.map((p) => (
|
||||
<InputSelect.Item
|
||||
key={p.image_provider_id}
|
||||
value={p.image_provider_id}
|
||||
icon={() => (
|
||||
<ProviderIcon
|
||||
provider={p.provider_name}
|
||||
size={16}
|
||||
/>
|
||||
)}
|
||||
>
|
||||
{p.title}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Group>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item
|
||||
value={NO_DEFAULT_VALUE}
|
||||
icon={SvgSlash}
|
||||
>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03">
|
||||
{" "}
|
||||
(Disable Image Generation)
|
||||
</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using image generation.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** models will no longer be used to generate images.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
{activeProvider && (
|
||||
<modal.Provider>
|
||||
<ImageGenerationConnectionModal
|
||||
|
||||
@@ -23,6 +23,7 @@ import {
|
||||
BedrockModelResponse,
|
||||
LMStudioModelResponse,
|
||||
LiteLLMProxyModelResponse,
|
||||
BifrostModelResponse,
|
||||
ModelConfiguration,
|
||||
LLMProviderName,
|
||||
BedrockFetchParams,
|
||||
@@ -30,8 +31,9 @@ import {
|
||||
LMStudioFetchParams,
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgOpenrouter } from "@opal/icons";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -41,6 +43,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"ollama_chat",
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -78,6 +81,7 @@ export const getProviderIcon = (
|
||||
bedrock_converse: SvgAws,
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -263,8 +267,11 @@ export const fetchOpenRouterModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse OpenRouter model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -319,8 +326,11 @@ export const fetchLMStudioModels = async (
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse LM Studio model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
@@ -343,6 +353,64 @@ export const fetchLMStudioModels = async (
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches Bifrost models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchBifrostModels = async (
|
||||
params: BifrostFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/bifrost/available-models", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch (jsonError) {
|
||||
console.warn(
|
||||
"Failed to parse Bifrost model fetch error response",
|
||||
jsonError
|
||||
);
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: BifrostModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches LiteLLM Proxy models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
@@ -456,6 +524,13 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.BIFROST:
|
||||
return fetchBifrostModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -469,6 +544,7 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -1,32 +1,25 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState, useReducer } from "react";
|
||||
import { useEffect, useMemo, useState, useReducer } from "react";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgGlobe,
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
SEARCH_PROVIDER_DETAILS,
|
||||
@@ -58,6 +51,10 @@ import {
|
||||
} from "@/app/admin/configuration/web-search/WebProviderModalReducer";
|
||||
import { connectProviderFlow } from "@/app/admin/configuration/web-search/connectProviderFlow";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const route = ADMIN_ROUTES.WEB_SEARCH;
|
||||
|
||||
interface WebSearchProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
@@ -76,27 +73,151 @@ interface WebContentProviderView {
|
||||
has_api_key: boolean;
|
||||
}
|
||||
|
||||
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
|
||||
isHovered: boolean;
|
||||
onMouseEnter: () => void;
|
||||
onMouseLeave: () => void;
|
||||
children: React.ReactNode;
|
||||
interface DisconnectTargetState {
|
||||
id: number;
|
||||
label: string;
|
||||
category: "search" | "content";
|
||||
providerType: string;
|
||||
}
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
function WebSearchDisconnectModal({
|
||||
disconnectTarget,
|
||||
searchProviders,
|
||||
contentProviders,
|
||||
replacementProviderId,
|
||||
onReplacementChange,
|
||||
onClose,
|
||||
onDisconnect,
|
||||
}: {
|
||||
disconnectTarget: DisconnectTargetState;
|
||||
searchProviders: WebSearchProviderView[];
|
||||
contentProviders: WebContentProviderView[];
|
||||
replacementProviderId: string | null;
|
||||
onReplacementChange: (id: string | null) => void;
|
||||
onClose: () => void;
|
||||
onDisconnect: () => void;
|
||||
}) {
|
||||
const isSearch = disconnectTarget.category === "search";
|
||||
|
||||
// Determine if the target is currently the active/selected provider
|
||||
const isActive = isSearch
|
||||
? searchProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false
|
||||
: contentProviders.find((p) => p.id === disconnectTarget.id)?.is_active ??
|
||||
false;
|
||||
|
||||
// Find other configured providers as replacements
|
||||
const replacementOptions = isSearch
|
||||
? searchProviders.filter(
|
||||
(p) => p.id !== disconnectTarget.id && p.id > 0 && p.has_api_key
|
||||
)
|
||||
: contentProviders.filter(
|
||||
(p) =>
|
||||
p.id !== disconnectTarget.id &&
|
||||
p.provider_type !== "onyx_web_crawler" &&
|
||||
p.id > 0 &&
|
||||
p.has_api_key
|
||||
);
|
||||
|
||||
const needsReplacement = isActive;
|
||||
const hasReplacements = replacementOptions.length > 0;
|
||||
|
||||
const getLabel = (p: { name: string; provider_type: string }) => {
|
||||
if (isSearch) {
|
||||
const details =
|
||||
SEARCH_PROVIDER_DETAILS[p.provider_type as WebSearchProviderType];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
}
|
||||
const details = CONTENT_PROVIDER_DETAILS[p.provider_type];
|
||||
return details?.label ?? p.name ?? p.provider_type;
|
||||
};
|
||||
|
||||
const categoryLabel = isSearch ? "search engine" : "web crawler";
|
||||
const featureLabel = isSearch ? "web search" : "web crawling";
|
||||
const disableLabel = isSearch ? "Disable Web Search" : "Disable Web Crawling";
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && hasReplacements && !replacementProviderId) {
|
||||
const first = replacementOptions[0];
|
||||
if (first) onReplacementChange(String(first.id));
|
||||
}
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
{/* TODO(@raunakab): migrate to opal Button once HoverIconButtonProps typing is resolved */}
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectTarget.label}`}
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}. Search history will be preserved.
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => onReplacementChange(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement provider" />
|
||||
<InputSelect.Content>
|
||||
{replacementOptions.map((p) => (
|
||||
<InputSelect.Item key={p.id} value={String(p.id)}>
|
||||
{getLabel(p)}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03"> ({disableLabel})</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.label}</b> is currently the active{" "}
|
||||
{categoryLabel}.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Connect another provider to continue using {featureLabel}.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
{isSearch ? "Web search" : "Web crawling"} will no longer be routed
|
||||
through <b>{disconnectTarget.label}</b>.
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
Search history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -105,6 +226,11 @@ export default function Page() {
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
);
|
||||
const [disconnectTarget, setDisconnectTarget] =
|
||||
useState<DisconnectTargetState | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [contentModal, dispatchContentModal] = useReducer(
|
||||
WebProviderModalReducer,
|
||||
initialWebProviderModalState
|
||||
@@ -113,8 +239,6 @@ export default function Page() {
|
||||
const [contentActivationError, setContentActivationError] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
data: searchProvidersData,
|
||||
error: searchProvidersError,
|
||||
@@ -833,6 +957,67 @@ export default function Page() {
|
||||
});
|
||||
};
|
||||
|
||||
const handleDisconnectProvider = async () => {
|
||||
if (!disconnectTarget) return;
|
||||
const { id, category } = disconnectTarget;
|
||||
|
||||
try {
|
||||
// If a replacement was selected (not "No Default"), activate it first
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
const repId = Number(replacementProviderId);
|
||||
const activateEndpoint =
|
||||
category === "search"
|
||||
? `/api/admin/web-search/search-providers/${repId}/activate`
|
||||
: `/api/admin/web-search/content-providers/${repId}/activate`;
|
||||
const activateResp = await fetch(activateEndpoint, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
if (!activateResp.ok) {
|
||||
const errorBody = await activateResp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`/api/admin/web-search/${category}-providers/${id}`,
|
||||
{ method: "DELETE" }
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch((parseErr) => {
|
||||
console.error("Failed to parse disconnect error response:", parseErr);
|
||||
return {};
|
||||
});
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to disconnect provider."
|
||||
);
|
||||
}
|
||||
|
||||
toast.success(`${disconnectTarget.label} disconnected`);
|
||||
await mutateSearchProviders();
|
||||
await mutateContentProviders();
|
||||
} catch (error) {
|
||||
console.error("Failed to disconnect web search provider:", error);
|
||||
const message =
|
||||
error instanceof Error ? error.message : "Unexpected error occurred.";
|
||||
if (category === "search") {
|
||||
setActivationError(message);
|
||||
} else {
|
||||
setContentActivationError(message);
|
||||
}
|
||||
} finally {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Root>
|
||||
@@ -894,149 +1079,79 @@ export default function Page() {
|
||||
provider
|
||||
);
|
||||
const isActive = provider?.is_active ?? false;
|
||||
const isHighlighted = isActive;
|
||||
const providerId = provider?.id;
|
||||
const canOpenModal =
|
||||
isBuiltInSearchProviderType(providerType);
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: canOpenModal
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isActive
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
return (
|
||||
<Select
|
||||
key={`${key}-${providerType}`}
|
||||
icon={() =>
|
||||
logoSrc ? (
|
||||
<Image
|
||||
src={logoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
onConnect={
|
||||
canOpenModal
|
||||
? () => {
|
||||
openSearchModal(providerType, provider);
|
||||
setActivationError(null);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
: undefined
|
||||
}
|
||||
onSelect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
onDeselect={
|
||||
providerId
|
||||
? () => {
|
||||
void handleDeactivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => {
|
||||
void handleActivateSearchProvider(providerId);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `search-${key}-${providerType}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${key}-${providerType}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isHighlighted
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc,
|
||||
alt: `${label} logo`,
|
||||
size: 16,
|
||||
isHighlighted,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (!canOpenModal) return;
|
||||
: undefined
|
||||
}
|
||||
onEdit={
|
||||
isConfigured && canOpenModal
|
||||
? () => {
|
||||
openSearchModal(
|
||||
providerType as WebSearchProviderType,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
isConfigured && provider && provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "search",
|
||||
providerType,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
)}
|
||||
@@ -1076,161 +1191,81 @@ export default function Page() {
|
||||
const isCurrentCrawler =
|
||||
provider.provider_type === currentContentProviderType;
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
icon: "arrow" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
},
|
||||
};
|
||||
}
|
||||
const status: "disconnected" | "connected" | "selected" =
|
||||
!isConfigured
|
||||
? "disconnected"
|
||||
: isCurrentCrawler
|
||||
? "selected"
|
||||
: "connected";
|
||||
|
||||
if (isCurrentCrawler) {
|
||||
return {
|
||||
label: "Current Crawler",
|
||||
icon: "check" as const,
|
||||
disabled: false,
|
||||
onClick: () => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
},
|
||||
};
|
||||
}
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
icon: "arrow-circle" as const,
|
||||
disabled: !canActivate,
|
||||
onClick: canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const contentButtonKey = `content-${provider.provider_type}-${provider.id}`;
|
||||
const isContentButtonHovered =
|
||||
hoveredButtonKey === contentButtonKey;
|
||||
const isContentCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleContentCardClick = () => {
|
||||
if (isContentCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
const contentLogoSrc =
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]?.logoSrc;
|
||||
|
||||
return (
|
||||
<div
|
||||
<Select
|
||||
key={`${provider.provider_type}-${provider.id}`}
|
||||
onClick={
|
||||
isContentCardClickable
|
||||
? handleContentCardClick
|
||||
icon={() =>
|
||||
contentLogoSrc ? (
|
||||
<Image
|
||||
src={contentLogoSrc}
|
||||
alt={`${label} logo`}
|
||||
width={16}
|
||||
height={16}
|
||||
/>
|
||||
) : provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : (
|
||||
<SvgGlobe size={16} />
|
||||
)
|
||||
}
|
||||
title={label}
|
||||
description={subtitle}
|
||||
status={status}
|
||||
selectedLabel="Current Crawler"
|
||||
onConnect={() => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
setContentActivationError(null);
|
||||
}}
|
||||
onSelect={
|
||||
canActivate
|
||||
? () => {
|
||||
void handleActivateContentProvider(provider);
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-1 bg-background-neutral-00",
|
||||
isCurrentCrawler
|
||||
? "border-action-link-05"
|
||||
: "border-border-01",
|
||||
isContentCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 px-2 py-1">
|
||||
{renderLogo({
|
||||
logoSrc:
|
||||
CONTENT_PROVIDER_DETAILS[provider.provider_type]
|
||||
?.logoSrc,
|
||||
alt: `${label} logo`,
|
||||
fallback:
|
||||
provider.provider_type === "onyx_web_crawler" ? (
|
||||
<SvgOnyxLogo size={16} />
|
||||
) : undefined,
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
})}
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
openContentModal(
|
||||
provider.provider_type,
|
||||
provider
|
||||
);
|
||||
}}
|
||||
aria-label={`Edit ${label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isContentButtonHovered}
|
||||
onMouseEnter={() =>
|
||||
setHoveredButtonKey(contentButtonKey)
|
||||
onDeselect={() => {
|
||||
void handleDeactivateContentProvider(
|
||||
providerId,
|
||||
provider.provider_type
|
||||
);
|
||||
}}
|
||||
onEdit={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured
|
||||
? () => {
|
||||
openContentModal(provider.provider_type, provider);
|
||||
}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Disabled
|
||||
disabled={
|
||||
buttonState.disabled || !buttonState.onClick
|
||||
}
|
||||
>
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
: undefined
|
||||
}
|
||||
onDisconnect={
|
||||
provider.provider_type !== "onyx_web_crawler" &&
|
||||
isConfigured &&
|
||||
provider.id > 0
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
id: provider.id,
|
||||
label,
|
||||
category: "content",
|
||||
providerType: provider.provider_type,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
@@ -1238,6 +1273,21 @@ export default function Page() {
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
{disconnectTarget && (
|
||||
<WebSearchDisconnectModal
|
||||
disconnectTarget={disconnectTarget}
|
||||
searchProviders={searchProviders}
|
||||
contentProviders={combinedContentProviders}
|
||||
replacementProviderId={replacementProviderId}
|
||||
onReplacementChange={setReplacementProviderId}
|
||||
onClose={() => {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
onDisconnect={() => void handleDisconnectProvider()}
|
||||
/>
|
||||
)}
|
||||
|
||||
<WebProviderSetupModal
|
||||
isOpen={selectedProviderType !== null}
|
||||
onClose={() => {
|
||||
|
||||
@@ -19,6 +19,10 @@
|
||||
background-color: var(--background-neutral-00);
|
||||
border: 1px solid var(--status-error-05);
|
||||
}
|
||||
.input-error:focus:not(:active),
|
||||
.input-error:focus-within:not(:active) {
|
||||
box-shadow: inset 0px 0px 0px 2px var(--background-tint-04);
|
||||
}
|
||||
|
||||
.input-disabled {
|
||||
background-color: var(--background-neutral-03);
|
||||
|
||||
@@ -133,7 +133,7 @@ async function createFederatedConnector(
|
||||
|
||||
async function updateFederatedConnector(
|
||||
id: number,
|
||||
credentials: CredentialForm,
|
||||
credentials: CredentialForm | null,
|
||||
config?: ConfigForm
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
try {
|
||||
@@ -143,7 +143,7 @@ async function updateFederatedConnector(
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
credentials,
|
||||
credentials: credentials ?? undefined,
|
||||
config: config || {},
|
||||
}),
|
||||
});
|
||||
@@ -201,7 +201,9 @@ export function FederatedConnectorForm({
|
||||
const isEditMode = connectorId !== undefined;
|
||||
|
||||
const [formState, setFormState] = useState<FormState>({
|
||||
credentials: preloadedConnectorData?.credentials || {},
|
||||
// In edit mode, don't populate credentials with masked values from the API.
|
||||
// Masked values (e.g. "••••••••••••") would be saved back and corrupt the real credentials.
|
||||
credentials: isEditMode ? {} : preloadedConnectorData?.credentials || {},
|
||||
config: preloadedConnectorData?.config || {},
|
||||
schema: preloadedCredentialSchema?.credentials || null,
|
||||
configurationSchema: null,
|
||||
@@ -209,6 +211,7 @@ export function FederatedConnectorForm({
|
||||
configurationSchemaError: null,
|
||||
connectorError: null,
|
||||
});
|
||||
const [credentialsModified, setCredentialsModified] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [submitMessage, setSubmitMessage] = useState<string | null>(null);
|
||||
const [submitSuccess, setSubmitSuccess] = useState<boolean | null>(null);
|
||||
@@ -333,6 +336,7 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
|
||||
const handleCredentialChange = (key: string, value: string) => {
|
||||
setCredentialsModified(true);
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
credentials: {
|
||||
@@ -354,6 +358,11 @@ export function FederatedConnectorForm({
|
||||
|
||||
const handleValidateCredentials = async () => {
|
||||
if (!formState.schema) return;
|
||||
if (isEditMode && !credentialsModified) {
|
||||
setSubmitMessage("Enter new credential values before validating.");
|
||||
setSubmitSuccess(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsValidating(true);
|
||||
setSubmitMessage(null);
|
||||
@@ -411,8 +420,10 @@ export function FederatedConnectorForm({
|
||||
setSubmitSuccess(null);
|
||||
|
||||
try {
|
||||
// Validate required fields
|
||||
if (formState.schema) {
|
||||
const shouldValidateCredentials = !isEditMode || credentialsModified;
|
||||
|
||||
// Validate required fields (skip for credentials in edit mode when unchanged)
|
||||
if (formState.schema && shouldValidateCredentials) {
|
||||
const missingRequired = Object.entries(formState.schema)
|
||||
.filter(
|
||||
([key, field]) => field.required && !formState.credentials[key]
|
||||
@@ -442,16 +453,20 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
setConfigValidationErrors({});
|
||||
|
||||
// Validate credentials before creating/updating
|
||||
const validation = await validateCredentials(
|
||||
connector,
|
||||
formState.credentials
|
||||
);
|
||||
if (!validation.success) {
|
||||
setSubmitMessage(`Credential validation failed: ${validation.message}`);
|
||||
setSubmitSuccess(false);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
// Validate credentials before creating/updating (skip in edit mode when unchanged)
|
||||
if (shouldValidateCredentials) {
|
||||
const validation = await validateCredentials(
|
||||
connector,
|
||||
formState.credentials
|
||||
);
|
||||
if (!validation.success) {
|
||||
setSubmitMessage(
|
||||
`Credential validation failed: ${validation.message}`
|
||||
);
|
||||
setSubmitSuccess(false);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Create or update the connector
|
||||
@@ -459,7 +474,7 @@ export function FederatedConnectorForm({
|
||||
isEditMode && connectorId
|
||||
? await updateFederatedConnector(
|
||||
connectorId,
|
||||
formState.credentials,
|
||||
credentialsModified ? formState.credentials : null,
|
||||
formState.config
|
||||
)
|
||||
: await createFederatedConnector(
|
||||
@@ -538,14 +553,16 @@ export function FederatedConnectorForm({
|
||||
id={fieldKey}
|
||||
type={fieldSpec.secret ? "password" : "text"}
|
||||
placeholder={
|
||||
fieldSpec.example
|
||||
? String(fieldSpec.example)
|
||||
: fieldSpec.description
|
||||
isEditMode && !credentialsModified
|
||||
? "•••••••• (leave blank to keep current value)"
|
||||
: fieldSpec.example
|
||||
? String(fieldSpec.example)
|
||||
: fieldSpec.description
|
||||
}
|
||||
value={formState.credentials[fieldKey] || ""}
|
||||
onChange={(e) => handleCredentialChange(fieldKey, e.target.value)}
|
||||
className="w-96"
|
||||
required={fieldSpec.required}
|
||||
required={!isEditMode && fieldSpec.required}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
* and support various display sizes.
|
||||
*/
|
||||
import React from "react";
|
||||
import { SvgBifrost } from "@opal/icons";
|
||||
import { render } from "@tests/setup/test-utils";
|
||||
import { GithubIcon, GitbookIcon, ConfluenceIcon } from "./icons";
|
||||
|
||||
@@ -51,4 +52,15 @@ describe("Logo Icons", () => {
|
||||
render(<GithubIcon size={100} className="custom-class" />);
|
||||
}).not.toThrow();
|
||||
});
|
||||
|
||||
test("renders the Bifrost icon with theme-aware colors", () => {
|
||||
const { container } = render(
|
||||
<SvgBifrost size={32} className="custom text-red-500 dark:text-black" />
|
||||
);
|
||||
const icon = container.querySelector("svg");
|
||||
|
||||
expect(icon).toBeInTheDocument();
|
||||
expect(icon).toHaveClass("custom", "text-[#33C19E]", "dark:text-white");
|
||||
expect(icon).not.toHaveClass("text-red-500", "dark:text-black");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,6 +6,8 @@ import { INTERNAL_URL, IS_DEV } from "@/lib/constants";
|
||||
const TARGET_SAMPLE_RATE = 24000;
|
||||
const CHUNK_INTERVAL_MS = 250;
|
||||
const DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS = 1500;
|
||||
// When VAD-based auto-stop is disabled, force-stop after this much silence as a fallback
|
||||
const SILENCE_FALLBACK_TIMEOUT_MS = 10000;
|
||||
|
||||
interface TranscriptMessage {
|
||||
type: "transcript" | "error";
|
||||
@@ -58,6 +60,8 @@ class VoiceRecorderSession {
|
||||
private finalTranscriptDelivered = false;
|
||||
private lastDeliveredFinalText: string | null = null;
|
||||
private lastDeliveredFinalAtMs = 0;
|
||||
// Fallback timer: force-stop after extended silence when VAD auto-stop is disabled
|
||||
private silenceFallbackTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
// Callbacks to update React state
|
||||
private onTranscriptChange: (text: string) => void;
|
||||
@@ -174,6 +178,8 @@ class VoiceRecorderSession {
|
||||
async stop(): Promise<string | null> {
|
||||
if (!this.isActive) return this.transcript || null;
|
||||
|
||||
this.resetSilenceFallbackTimer();
|
||||
|
||||
// Stop audio capture
|
||||
if (this.sendInterval) {
|
||||
clearInterval(this.sendInterval);
|
||||
@@ -219,6 +225,7 @@ class VoiceRecorderSession {
|
||||
}
|
||||
|
||||
cleanup(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
if (this.sendInterval) clearInterval(this.sendInterval);
|
||||
if (this.scriptNode) this.scriptNode.disconnect();
|
||||
if (this.sourceNode) this.sourceNode.disconnect();
|
||||
@@ -274,6 +281,23 @@ class VoiceRecorderSession {
|
||||
});
|
||||
}
|
||||
|
||||
private resetSilenceFallbackTimer(): void {
|
||||
if (this.silenceFallbackTimer) {
|
||||
clearTimeout(this.silenceFallbackTimer);
|
||||
this.silenceFallbackTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private startSilenceFallbackTimer(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
this.silenceFallbackTimer = setTimeout(() => {
|
||||
// 10s of silence with no new speech — force-stop as a safety fallback
|
||||
if (this.isActive && this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
}, SILENCE_FALLBACK_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
private handleMessage = (event: MessageEvent): void => {
|
||||
try {
|
||||
const data: TranscriptMessage = JSON.parse(event.data);
|
||||
@@ -281,47 +305,53 @@ class VoiceRecorderSession {
|
||||
if (data.type === "transcript") {
|
||||
if (data.text) {
|
||||
this.transcript = data.text;
|
||||
this.onTranscriptChange(data.text);
|
||||
// Only push live updates to React while actively recording.
|
||||
// After stop(), the final transcript is returned via stopResolver
|
||||
// instead — this prevents stale text from reappearing in the
|
||||
// input box when the user clears it and starts a new recording.
|
||||
if (this.isActive) {
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
}
|
||||
|
||||
if (data.is_final && data.text) {
|
||||
// VAD detected silence - trigger callback (only once per utterance)
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.autoStopOnSilence &&
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
// Resolve stop promise if waiting — must run even after stop()
|
||||
// so the caller receives the final transcript.
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
|
||||
// Auto-stop recording if enabled
|
||||
// Skip VAD logic if session is no longer active
|
||||
if (!this.isActive) return;
|
||||
|
||||
if (this.autoStopOnSilence) {
|
||||
// Trigger stop callback to update React state
|
||||
// VAD detected silence — auto-stop and trigger callback
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
if (this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
} else {
|
||||
// If not auto-stopping, reset for next utterance
|
||||
this.transcript = "";
|
||||
this.finalTranscriptDelivered = false;
|
||||
this.onTranscriptChange("");
|
||||
this.resetBackendTranscript();
|
||||
}
|
||||
|
||||
// Resolve stop promise if waiting
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
// Auto-stop disabled (push-to-talk): ignore VAD, keep recording.
|
||||
// Start/reset a 10s fallback timer — if no new speech arrives,
|
||||
// force-stop to avoid recording silence indefinitely.
|
||||
this.startSilenceFallbackTimer();
|
||||
}
|
||||
}
|
||||
} else if (data.type === "error") {
|
||||
|
||||
@@ -13,6 +13,7 @@ export enum LLMProviderName {
|
||||
VERTEX_AI = "vertex_ai",
|
||||
BEDROCK = "bedrock",
|
||||
LITELLM_PROXY = "litellm_proxy",
|
||||
BIFROST = "bifrost",
|
||||
CUSTOM = "custom",
|
||||
}
|
||||
|
||||
@@ -165,6 +166,21 @@ export interface LiteLLMProxyModelResponse {
|
||||
model_name: string;
|
||||
}
|
||||
|
||||
export interface BifrostFetchParams {
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface BifrostModelResponse {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface VertexAIFetchParams {
|
||||
model_configurations?: ModelConfiguration[];
|
||||
}
|
||||
@@ -182,5 +198,6 @@ export type FetchModelsParams =
|
||||
| OllamaFetchParams
|
||||
| OpenRouterFetchParams
|
||||
| LiteLLMProxyFetchParams
|
||||
| BifrostFetchParams
|
||||
| VertexAIFetchParams
|
||||
| LMStudioFetchParams;
|
||||
|
||||
@@ -53,6 +53,12 @@ export async function fetchVoicesByType(
|
||||
return fetch(`/api/admin/voice/voices?provider_type=${providerType}`);
|
||||
}
|
||||
|
||||
export async function deleteVoiceProvider(
|
||||
providerId: number
|
||||
): Promise<Response> {
|
||||
return fetch(`${VOICE_PROVIDERS_URL}/${providerId}`, { method: "DELETE" });
|
||||
}
|
||||
|
||||
export async function fetchLLMProviders(): Promise<Response> {
|
||||
return fetch("/api/admin/llm/provider");
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import {
|
||||
SvgBifrost,
|
||||
SvgCpu,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
@@ -26,6 +27,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OLLAMA_CHAT]: SvgOllama,
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
@@ -42,6 +44,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
@@ -58,6 +61,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OLLAMA_CHAT]: "Ollama",
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",
|
||||
|
||||
@@ -457,6 +457,7 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
<div
|
||||
tabIndex={-1}
|
||||
ref={closeButtonRef as React.RefObject<HTMLDivElement>}
|
||||
className="outline-none"
|
||||
>
|
||||
<DialogPrimitive.Close asChild>
|
||||
<Button
|
||||
|
||||
@@ -142,6 +142,7 @@ function PopoverContent({
|
||||
collisionPadding={8}
|
||||
className={cn(
|
||||
"bg-background-neutral-00 p-1 z-popover rounded-12 border shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
"flex flex-col",
|
||||
"max-h-[var(--radix-popover-content-available-height)]",
|
||||
"overflow-hidden",
|
||||
widthClasses[width]
|
||||
@@ -226,7 +227,7 @@ export function PopoverMenu({
|
||||
});
|
||||
|
||||
return (
|
||||
<Section alignItems="stretch">
|
||||
<Section alignItems="stretch" height="auto" className="flex-1 min-h-0">
|
||||
<ShadowDiv
|
||||
scrollContainerRef={scrollContainerRef}
|
||||
className="flex flex-col gap-1 max-h-[20rem] w-full"
|
||||
|
||||
@@ -105,7 +105,7 @@ export default function ShadowDiv({
|
||||
}, [containerRef, checkScroll]);
|
||||
|
||||
return (
|
||||
<div className="relative min-h-0">
|
||||
<div className="relative min-h-0 flex flex-col">
|
||||
<div
|
||||
ref={containerRef}
|
||||
className={cn("overflow-y-auto", className)}
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
SvgArrowRightCircle,
|
||||
SvgCheckSquare,
|
||||
SvgSettings,
|
||||
SvgUnplug,
|
||||
} from "@opal/icons";
|
||||
|
||||
const containerClasses = {
|
||||
@@ -35,6 +36,7 @@ export interface SelectProps
|
||||
onSelect?: () => void;
|
||||
onDeselect?: () => void;
|
||||
onEdit?: () => void;
|
||||
onDisconnect?: () => void;
|
||||
|
||||
// Labels (customizable)
|
||||
connectLabel?: string;
|
||||
@@ -59,6 +61,7 @@ export default function Select({
|
||||
onSelect,
|
||||
onDeselect,
|
||||
onEdit,
|
||||
onDisconnect,
|
||||
connectLabel = "Connect",
|
||||
selectLabel = "Set as Default",
|
||||
selectedLabel = "Current Default",
|
||||
@@ -68,7 +71,7 @@ export default function Select({
|
||||
disabled,
|
||||
...rest
|
||||
}: SelectProps) {
|
||||
const sizeClass = medium ? "h-[3.75rem]" : "h-[4.25rem]";
|
||||
const sizeClass = medium ? "h-[3.75rem]" : "min-h-[3.75rem] max-h-[5.25rem]";
|
||||
const containerClass = containerClasses[status];
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
@@ -121,7 +124,7 @@ export default function Select({
|
||||
</div>
|
||||
|
||||
{/* Right section - Actions */}
|
||||
<div className="flex items-center justify-end gap-1">
|
||||
<div className="flex flex-col h-full items-end justify-between gap-1">
|
||||
{/* Disconnected: Show Connect button */}
|
||||
{isDisconnected && (
|
||||
<Disabled disabled={disabled || !onConnect}>
|
||||
@@ -149,18 +152,32 @@ export default function Select({
|
||||
{selectLabel}
|
||||
</SelectButton>
|
||||
</Disabled>
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
<div className="flex px-1 gap-1">
|
||||
{onDisconnect && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgUnplug}
|
||||
tooltip="Disconnect"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onDisconnect)}
|
||||
aria-label={`Disconnect ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -177,18 +194,32 @@ export default function Select({
|
||||
{selectedLabel}
|
||||
</SelectButton>
|
||||
</Disabled>
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
<div className="flex px-1 gap-1">
|
||||
{onDisconnect && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgUnplug}
|
||||
tooltip="Disconnect"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onDisconnect)}
|
||||
aria-label={`Disconnect ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
{onEdit && (
|
||||
<Disabled disabled={disabled}>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={noProp(onEdit)}
|
||||
aria-label={`Edit ${title}`}
|
||||
/>
|
||||
</Disabled>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -12,7 +12,7 @@ export interface FieldContextType {
|
||||
|
||||
export type FormFieldRootProps = React.HTMLAttributes<HTMLDivElement> & {
|
||||
name?: string;
|
||||
state: FormFieldState;
|
||||
state?: FormFieldState;
|
||||
required?: boolean;
|
||||
id?: string;
|
||||
};
|
||||
|
||||
@@ -984,8 +984,8 @@ function ChatPreferencesSettings() {
|
||||
/>
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Send"
|
||||
description="Automatically send voice input when recording stops."
|
||||
title="Auto-Send on Pause"
|
||||
description="Automatically send voice input when you stop speaking."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_send ?? false}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR from "swr";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers } from "@opal/icons";
|
||||
@@ -14,16 +13,14 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
import useGroupMemberCandidates from "./useGroupMemberCandidates";
|
||||
import {
|
||||
createGroup,
|
||||
updateAgentGroupSharing,
|
||||
updateDocSetGroupSharing,
|
||||
saveTokenLimits,
|
||||
} from "./svc";
|
||||
import { apiKeyToMemberRow, memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import { memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
|
||||
import TokenLimitSection from "./TokenLimitSection";
|
||||
import type { TokenLimit } from "./TokenLimitSection";
|
||||
@@ -41,22 +38,7 @@ function CreateGroupPage() {
|
||||
{ tokenBudget: null, periodHours: null },
|
||||
]);
|
||||
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const isLoading = usersLoading || apiKeysLoading;
|
||||
const error = usersError ?? apiKeysError;
|
||||
|
||||
const allRows: MemberRow[] = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
const { rows: allRows, isLoading, error } = useGroupMemberCandidates();
|
||||
|
||||
async function handleCreate() {
|
||||
const trimmed = groupName.trim();
|
||||
@@ -133,11 +115,11 @@ function CreateGroupPage() {
|
||||
{/* Members table */}
|
||||
{isLoading && <SimpleLoader />}
|
||||
|
||||
{error && (
|
||||
{error ? (
|
||||
<Text as="p" secondaryBody text03>
|
||||
Failed to load users.
|
||||
</Text>
|
||||
)}
|
||||
) : null}
|
||||
|
||||
{!isLoading && !error && (
|
||||
<Section
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR, { useSWRConfig } from "swr";
|
||||
import useGroupMemberCandidates from "./useGroupMemberCandidates";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers, SvgTrash, SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
@@ -19,20 +20,9 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import type {
|
||||
ApiKeyDescriptor,
|
||||
MemberRow,
|
||||
TokenRateLimitDisplay,
|
||||
} from "./interfaces";
|
||||
import {
|
||||
apiKeyToMemberRow,
|
||||
baseColumns,
|
||||
memberTableColumns,
|
||||
tc,
|
||||
PAGE_SIZE,
|
||||
} from "./shared";
|
||||
import type { MemberRow, TokenRateLimitDisplay } from "./interfaces";
|
||||
import { baseColumns, memberTableColumns, tc, PAGE_SIZE } from "./shared";
|
||||
import {
|
||||
USER_GROUP_URL,
|
||||
renameGroup,
|
||||
@@ -104,18 +94,15 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
const initialAgentIdsRef = useRef<number[]>([]);
|
||||
const initialDocSetIdsRef = useRef<number[]>([]);
|
||||
|
||||
// Users and API keys
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
// Users + service accounts (curator-accessible — see hook docs).
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
rows: allRows,
|
||||
isLoading: candidatesLoading,
|
||||
error: candidatesError,
|
||||
} = useGroupMemberCandidates();
|
||||
|
||||
const isLoading =
|
||||
groupLoading || usersLoading || apiKeysLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? usersError ?? apiKeysError;
|
||||
const isLoading = groupLoading || candidatesLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? candidatesError;
|
||||
|
||||
// Pre-populate form when group data loads
|
||||
useEffect(() => {
|
||||
@@ -145,12 +132,6 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
}
|
||||
}, [tokenRateLimits]);
|
||||
|
||||
const allRows = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
|
||||
const memberRows = useMemo(() => {
|
||||
const selected = new Set(selectedUserIds);
|
||||
return allRows.filter((r) => selected.has(r.id ?? r.email));
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
|
||||
// Curator-accessible listing of all users (and service-account entries via
|
||||
// `?include_api_keys=true`). The admin-only `/manage/users/accepted/all` and
|
||||
// `/manage/users/invited` endpoints 403 for global curators, which used to
|
||||
// break the Edit Group page entirely — see useGroupMemberCandidates docs.
|
||||
const GROUP_MEMBER_CANDIDATES_URL = "/api/manage/users?include_api_keys=true";
|
||||
const ADMIN_API_KEYS_URL = "/api/admin/api-key";
|
||||
import { UserStatus, type UserRole } from "@/lib/types";
|
||||
|
||||
// Mirrors `DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN` on the backend; service-account
|
||||
// users are identified by this email suffix because release/v3.1 does not yet
|
||||
// expose `account_type` on the FullUserSnapshot returned from `/manage/users`.
|
||||
const API_KEY_EMAIL_SUFFIX = "@onyxapikey.ai";
|
||||
|
||||
function isApiKeyEmail(email: string): boolean {
|
||||
return email.endsWith(API_KEY_EMAIL_SUFFIX);
|
||||
}
|
||||
import type {
|
||||
UserGroupInfo,
|
||||
UserRow,
|
||||
} from "@/refresh-pages/admin/UsersPage/interfaces";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
|
||||
// Backend response shape for `/api/manage/users?include_api_keys=true`. The
|
||||
// existing `AllUsersResponse` in `lib/types.ts` types `accepted` as `User[]`,
|
||||
// which is missing fields the table needs (`personal_name`, `account_type`,
|
||||
// `groups`, etc.), so we declare an accurate local type here.
|
||||
interface FullUserSnapshot {
|
||||
id: string;
|
||||
email: string;
|
||||
role: UserRole;
|
||||
is_active: boolean;
|
||||
password_configured: boolean;
|
||||
personal_name: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
groups: UserGroupInfo[];
|
||||
is_scim_synced: boolean;
|
||||
}
|
||||
|
||||
interface ManageUsersResponse {
|
||||
accepted: FullUserSnapshot[];
|
||||
invited: { email: string }[];
|
||||
slack_users: FullUserSnapshot[];
|
||||
accepted_pages: number;
|
||||
invited_pages: number;
|
||||
slack_users_pages: number;
|
||||
}
|
||||
|
||||
function snapshotToMemberRow(snapshot: FullUserSnapshot): MemberRow {
|
||||
return {
|
||||
id: snapshot.id,
|
||||
email: snapshot.email,
|
||||
role: snapshot.role,
|
||||
status: snapshot.is_active ? UserStatus.ACTIVE : UserStatus.INACTIVE,
|
||||
is_active: snapshot.is_active,
|
||||
is_scim_synced: snapshot.is_scim_synced,
|
||||
personal_name: snapshot.personal_name,
|
||||
created_at: snapshot.created_at,
|
||||
updated_at: snapshot.updated_at,
|
||||
groups: snapshot.groups,
|
||||
};
|
||||
}
|
||||
|
||||
function serviceAccountToMemberRow(
|
||||
snapshot: FullUserSnapshot,
|
||||
apiKey: ApiKeyDescriptor | undefined
|
||||
): MemberRow {
|
||||
return {
|
||||
id: snapshot.id,
|
||||
email: "Service Account",
|
||||
role: apiKey?.api_key_role ?? snapshot.role,
|
||||
status: UserStatus.ACTIVE,
|
||||
is_active: true,
|
||||
is_scim_synced: false,
|
||||
personal_name:
|
||||
apiKey?.api_key_name ?? snapshot.personal_name ?? "Unnamed Key",
|
||||
created_at: null,
|
||||
updated_at: null,
|
||||
groups: [],
|
||||
api_key_display: apiKey?.api_key_display,
|
||||
};
|
||||
}
|
||||
|
||||
interface UseGroupMemberCandidatesResult {
|
||||
/** Active users + service-account rows, in the order the table expects. */
|
||||
rows: MemberRow[];
|
||||
/** Subset of `rows` representing real (non-service-account) users. */
|
||||
userRows: MemberRow[];
|
||||
isLoading: boolean;
|
||||
error: unknown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the candidate list for the group create/edit member pickers.
|
||||
*
|
||||
* Hits `/api/manage/users?include_api_keys=true`, which is gated by
|
||||
* `current_curator_or_admin_user` on the backend, so this works for both
|
||||
* admins and global curators (the admin-only `/accepted/all` and `/invited`
|
||||
* endpoints used to be called here, which 403'd for global curators and broke
|
||||
* the Edit Group page entirely).
|
||||
*
|
||||
* For admins, we additionally fetch `/admin/api-key` to enrich service-account
|
||||
* rows with the masked api-key display string. That call is admin-only and is
|
||||
* skipped for curators; its failure is non-fatal.
|
||||
*/
|
||||
export default function useGroupMemberCandidates(): UseGroupMemberCandidatesResult {
|
||||
const { isAdmin } = useUser();
|
||||
|
||||
const {
|
||||
data: usersData,
|
||||
isLoading: usersLoading,
|
||||
error: usersError,
|
||||
} = useSWR<ManageUsersResponse>(
|
||||
GROUP_MEMBER_CANDIDATES_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const { data: apiKeys, isLoading: apiKeysLoading } = useSWR<
|
||||
ApiKeyDescriptor[]
|
||||
>(isAdmin ? ADMIN_API_KEYS_URL : null, errorHandlingFetcher);
|
||||
|
||||
const apiKeysByUserId = useMemo(() => {
|
||||
const map = new Map<string, ApiKeyDescriptor>();
|
||||
for (const key of apiKeys ?? []) map.set(key.user_id, key);
|
||||
return map;
|
||||
}, [apiKeys]);
|
||||
|
||||
const { rows, userRows } = useMemo(() => {
|
||||
const accepted = usersData?.accepted ?? [];
|
||||
const userRowsLocal: MemberRow[] = [];
|
||||
const serviceAccountRows: MemberRow[] = [];
|
||||
for (const snapshot of accepted) {
|
||||
if (!snapshot.is_active) continue;
|
||||
if (isApiKeyEmail(snapshot.email)) {
|
||||
serviceAccountRows.push(
|
||||
serviceAccountToMemberRow(snapshot, apiKeysByUserId.get(snapshot.id))
|
||||
);
|
||||
} else {
|
||||
userRowsLocal.push(snapshotToMemberRow(snapshot));
|
||||
}
|
||||
}
|
||||
return {
|
||||
rows: [...userRowsLocal, ...serviceAccountRows],
|
||||
userRows: userRowsLocal,
|
||||
};
|
||||
}, [usersData, apiKeysByUserId]);
|
||||
|
||||
return {
|
||||
rows,
|
||||
userRows,
|
||||
isLoading: usersLoading || (isAdmin && apiKeysLoading),
|
||||
error: usersError,
|
||||
};
|
||||
}
|
||||
559
web/src/refresh-pages/admin/HooksPage/HookFormModal.tsx
Normal file
559
web/src/refresh-pages/admin/HooksPage/HookFormModal.tsx
Normal file
@@ -0,0 +1,559 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import {
|
||||
SvgCheckCircle,
|
||||
SvgHookNodes,
|
||||
SvgLoader,
|
||||
SvgRevert,
|
||||
} from "@opal/icons";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import { FormField } from "@/refresh-components/form/FormField";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
createHook,
|
||||
updateHook,
|
||||
HookAuthError,
|
||||
HookTimeoutError,
|
||||
HookConnectError,
|
||||
} from "@/refresh-pages/admin/HooksPage/svc";
|
||||
import type {
|
||||
HookFailStrategy,
|
||||
HookFormState,
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
HookUpdateRequest,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HookFormModalProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
/** When provided, the modal is in edit mode for this hook. */
|
||||
hook?: HookResponse;
|
||||
/** When provided (create mode), the hook point is pre-selected and locked. */
|
||||
spec?: HookPointMeta;
|
||||
onSuccess: (hook: HookResponse) => void;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function buildInitialState(
|
||||
hook: HookResponse | undefined,
|
||||
spec: HookPointMeta | undefined
|
||||
): HookFormState {
|
||||
if (hook) {
|
||||
return {
|
||||
name: hook.name,
|
||||
endpoint_url: hook.endpoint_url ?? "",
|
||||
api_key: "",
|
||||
fail_strategy: hook.fail_strategy,
|
||||
timeout_seconds: String(hook.timeout_seconds),
|
||||
};
|
||||
}
|
||||
return {
|
||||
name: "",
|
||||
endpoint_url: "",
|
||||
api_key: "",
|
||||
fail_strategy: spec?.default_fail_strategy ?? "hard",
|
||||
timeout_seconds: spec ? String(spec.default_timeout_seconds) : "30",
|
||||
};
|
||||
}
|
||||
|
||||
const SOFT_DESCRIPTION =
|
||||
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
|
||||
|
||||
const MAX_TIMEOUT_SECONDS = 600;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function HookFormModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
hook,
|
||||
spec,
|
||||
onSuccess,
|
||||
}: HookFormModalProps) {
|
||||
const isEdit = !!hook;
|
||||
const [form, setForm] = useState<HookFormState>(() =>
|
||||
buildInitialState(hook, spec)
|
||||
);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
// Tracks whether the user explicitly cleared the API key field in edit mode.
|
||||
// - false + empty field → key unchanged (omitted from PATCH)
|
||||
// - true + empty field → key cleared (api_key: null sent to backend)
|
||||
// - false + non-empty → new key provided (new value sent to backend)
|
||||
const [apiKeyCleared, setApiKeyCleared] = useState(false);
|
||||
const [touched, setTouched] = useState({
|
||||
name: false,
|
||||
endpoint_url: false,
|
||||
api_key: false,
|
||||
});
|
||||
const [apiKeyServerError, setApiKeyServerError] = useState(false);
|
||||
const [endpointServerError, setEndpointServerError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [timeoutServerError, setTimeoutServerError] = useState(false);
|
||||
|
||||
function touch(key: keyof typeof touched) {
|
||||
setTouched((prev) => ({ ...prev, [key]: true }));
|
||||
}
|
||||
|
||||
function handleOpenChange(next: boolean) {
|
||||
if (!next) {
|
||||
if (isSubmitting) return;
|
||||
setTimeout(() => {
|
||||
setForm(buildInitialState(hook, spec));
|
||||
setIsConnected(false);
|
||||
setApiKeyCleared(false);
|
||||
setTouched({ name: false, endpoint_url: false, api_key: false });
|
||||
setApiKeyServerError(false);
|
||||
setEndpointServerError(null);
|
||||
setTimeoutServerError(false);
|
||||
}, 200);
|
||||
}
|
||||
onOpenChange(next);
|
||||
}
|
||||
|
||||
function set<K extends keyof HookFormState>(key: K, value: HookFormState[K]) {
|
||||
setForm((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const timeoutNum = parseFloat(form.timeout_seconds);
|
||||
const isTimeoutValid =
|
||||
!isNaN(timeoutNum) && timeoutNum > 0 && timeoutNum <= MAX_TIMEOUT_SECONDS;
|
||||
const isValid =
|
||||
form.name.trim().length > 0 &&
|
||||
form.endpoint_url.trim().length > 0 &&
|
||||
isTimeoutValid &&
|
||||
(isEdit || form.api_key.trim().length > 0);
|
||||
|
||||
const nameError = touched.name && !form.name.trim();
|
||||
const endpointEmptyError = touched.endpoint_url && !form.endpoint_url.trim();
|
||||
const endpointFieldError = endpointEmptyError
|
||||
? "Endpoint URL cannot be empty."
|
||||
: endpointServerError ?? undefined;
|
||||
const apiKeyEmptyError = !isEdit && touched.api_key && !form.api_key.trim();
|
||||
const apiKeyFieldError = apiKeyEmptyError
|
||||
? "API key cannot be empty."
|
||||
: apiKeyServerError
|
||||
? "Invalid API key."
|
||||
: undefined;
|
||||
|
||||
function handleTimeoutBlur() {
|
||||
if (!isTimeoutValid) {
|
||||
const fallback = hook?.timeout_seconds ?? spec?.default_timeout_seconds;
|
||||
if (fallback !== undefined) {
|
||||
set("timeout_seconds", String(fallback));
|
||||
if (timeoutServerError) setTimeoutServerError(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const hasChanges =
|
||||
isEdit && hook
|
||||
? form.name !== hook.name ||
|
||||
form.endpoint_url !== (hook.endpoint_url ?? "") ||
|
||||
form.fail_strategy !== hook.fail_strategy ||
|
||||
timeoutNum !== hook.timeout_seconds ||
|
||||
form.api_key.trim().length > 0 ||
|
||||
apiKeyCleared
|
||||
: true;
|
||||
|
||||
async function handleSubmit() {
|
||||
if (!isValid) return;
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
let result: HookResponse;
|
||||
if (isEdit && hook) {
|
||||
const req: HookUpdateRequest = {};
|
||||
if (form.name !== hook.name) req.name = form.name;
|
||||
if (form.endpoint_url !== (hook.endpoint_url ?? ""))
|
||||
req.endpoint_url = form.endpoint_url;
|
||||
if (form.fail_strategy !== hook.fail_strategy)
|
||||
req.fail_strategy = form.fail_strategy;
|
||||
if (timeoutNum !== hook.timeout_seconds)
|
||||
req.timeout_seconds = timeoutNum;
|
||||
if (form.api_key.trim().length > 0) {
|
||||
req.api_key = form.api_key;
|
||||
} else if (apiKeyCleared) {
|
||||
req.api_key = null;
|
||||
}
|
||||
if (Object.keys(req).length === 0) {
|
||||
setIsSubmitting(false);
|
||||
handleOpenChange(false);
|
||||
return;
|
||||
}
|
||||
result = await updateHook(hook.id, req);
|
||||
} else {
|
||||
if (!spec) {
|
||||
toast.error("No hook point specified.");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
result = await createHook({
|
||||
name: form.name,
|
||||
hook_point: spec.hook_point,
|
||||
endpoint_url: form.endpoint_url,
|
||||
...(form.api_key ? { api_key: form.api_key } : {}),
|
||||
fail_strategy: form.fail_strategy,
|
||||
timeout_seconds: timeoutNum,
|
||||
});
|
||||
}
|
||||
toast.success(isEdit ? "Hook updated." : "Hook created.");
|
||||
onSuccess(result);
|
||||
if (!isEdit) {
|
||||
setIsConnected(true);
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
}
|
||||
setIsSubmitting(false);
|
||||
handleOpenChange(false);
|
||||
} catch (err) {
|
||||
if (err instanceof HookAuthError) {
|
||||
setApiKeyServerError(true);
|
||||
} else if (err instanceof HookTimeoutError) {
|
||||
setTimeoutServerError(true);
|
||||
} else if (err instanceof HookConnectError) {
|
||||
setEndpointServerError(err.message || "Could not connect to endpoint.");
|
||||
} else {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Something went wrong."
|
||||
);
|
||||
}
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const hookPointDisplayName =
|
||||
spec?.display_name ?? spec?.hook_point ?? hook?.hook_point ?? "";
|
||||
const hookPointDescription = spec?.description;
|
||||
const docsUrl = spec?.docs_url;
|
||||
|
||||
const failStrategyDescription =
|
||||
form.fail_strategy === "soft"
|
||||
? SOFT_DESCRIPTION
|
||||
: spec?.fail_hard_description;
|
||||
|
||||
return (
|
||||
<Modal open={open} onOpenChange={handleOpenChange}>
|
||||
<Modal.Content width="md" height="fit">
|
||||
<Modal.Header
|
||||
icon={SvgHookNodes}
|
||||
title={isEdit ? "Manage Hook Extension" : "Set Up Hook Extension"}
|
||||
description={
|
||||
isEdit
|
||||
? undefined
|
||||
: "Connect an external API endpoint to extend the hook point."
|
||||
}
|
||||
onClose={() => handleOpenChange(false)}
|
||||
/>
|
||||
|
||||
<Modal.Body>
|
||||
{/* Hook point section header */}
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
title={hookPointDisplayName}
|
||||
description={hookPointDescription}
|
||||
rightChildren={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
alignItems="end"
|
||||
width="fit"
|
||||
height="fit"
|
||||
gap={0.25}
|
||||
>
|
||||
<div className="flex items-center gap-1">
|
||||
<SvgHookNodes
|
||||
style={{ width: "1rem", height: "1rem" }}
|
||||
className="text-text-03 shrink-0"
|
||||
/>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Hook Point
|
||||
</Text>
|
||||
</div>
|
||||
{docsUrl && (
|
||||
<a
|
||||
href={docsUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Documentation
|
||||
</Text>
|
||||
</a>
|
||||
)}
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
|
||||
<FormField className="w-full" state={nameError ? "error" : "idle"}>
|
||||
<FormField.Label>Display Name</FormField.Label>
|
||||
<FormField.Control>
|
||||
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
|
||||
<InputTypeIn
|
||||
value={form.name}
|
||||
onChange={(e) => set("name", e.target.value)}
|
||||
onBlur={() => touch("name")}
|
||||
placeholder="Name your extension at this hook point"
|
||||
variant={
|
||||
isSubmitting ? "disabled" : nameError ? "error" : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</FormField.Control>
|
||||
<FormField.Message
|
||||
messages={{ error: "Display name cannot be empty." }}
|
||||
/>
|
||||
</FormField>
|
||||
|
||||
<FormField className="w-full">
|
||||
<FormField.Label>Fail Strategy</FormField.Label>
|
||||
<FormField.Control>
|
||||
<InputSelect
|
||||
value={form.fail_strategy}
|
||||
onValueChange={(v) =>
|
||||
set("fail_strategy", v as HookFailStrategy)
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select strategy" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="soft">
|
||||
Log Error and Continue
|
||||
{spec?.default_fail_strategy === "soft" && (
|
||||
<>
|
||||
{" "}
|
||||
<Text color="text-03">(Default)</Text>
|
||||
</>
|
||||
)}
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="hard">
|
||||
Block Pipeline on Failure
|
||||
{spec?.default_fail_strategy === "hard" && (
|
||||
<>
|
||||
{" "}
|
||||
<Text color="text-03">(Default)</Text>
|
||||
</>
|
||||
)}
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</FormField.Control>
|
||||
<FormField.Description>
|
||||
{failStrategyDescription}
|
||||
</FormField.Description>
|
||||
</FormField>
|
||||
|
||||
<FormField
|
||||
className="w-full"
|
||||
state={timeoutServerError ? "error" : "idle"}
|
||||
>
|
||||
<FormField.Label>
|
||||
Timeout{" "}
|
||||
<Text font="main-ui-action" color="text-03">
|
||||
(seconds)
|
||||
</Text>
|
||||
</FormField.Label>
|
||||
<FormField.Control>
|
||||
<div className="[&_input]:!font-main-ui-mono [&_input::placeholder]:!font-main-ui-mono [&_input]:![appearance:textfield] [&_input::-webkit-outer-spin-button]:!appearance-none [&_input::-webkit-inner-spin-button]:!appearance-none w-full">
|
||||
<InputTypeIn
|
||||
type="number"
|
||||
value={form.timeout_seconds}
|
||||
onChange={(e) => {
|
||||
set("timeout_seconds", e.target.value);
|
||||
if (timeoutServerError) setTimeoutServerError(false);
|
||||
}}
|
||||
onBlur={handleTimeoutBlur}
|
||||
placeholder={
|
||||
spec ? String(spec.default_timeout_seconds) : undefined
|
||||
}
|
||||
variant={
|
||||
isSubmitting
|
||||
? "disabled"
|
||||
: timeoutServerError
|
||||
? "error"
|
||||
: undefined
|
||||
}
|
||||
showClearButton={false}
|
||||
rightSection={
|
||||
spec?.default_timeout_seconds !== undefined &&
|
||||
form.timeout_seconds !==
|
||||
String(spec.default_timeout_seconds) ? (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="xs"
|
||||
icon={SvgRevert}
|
||||
tooltip="Revert to Default"
|
||||
onClick={() =>
|
||||
set(
|
||||
"timeout_seconds",
|
||||
String(spec.default_timeout_seconds)
|
||||
)
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
/>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</FormField.Control>
|
||||
{!timeoutServerError && (
|
||||
<FormField.Description>
|
||||
Maximum time Onyx will wait for the endpoint to respond before
|
||||
applying the fail strategy. Must be greater than 0 and at most{" "}
|
||||
{MAX_TIMEOUT_SECONDS} seconds.
|
||||
</FormField.Description>
|
||||
)}
|
||||
<FormField.Message
|
||||
messages={{
|
||||
error: "Connection timed out. Try increasing the timeout.",
|
||||
}}
|
||||
/>
|
||||
</FormField>
|
||||
|
||||
<FormField
|
||||
className="w-full"
|
||||
state={endpointFieldError ? "error" : "idle"}
|
||||
>
|
||||
<FormField.Label>External API Endpoint URL</FormField.Label>
|
||||
<FormField.Control>
|
||||
<div className="[&_input::placeholder]:!font-main-ui-muted w-full">
|
||||
<InputTypeIn
|
||||
value={form.endpoint_url}
|
||||
onChange={(e) => {
|
||||
set("endpoint_url", e.target.value);
|
||||
if (endpointServerError) setEndpointServerError(null);
|
||||
}}
|
||||
onBlur={() => touch("endpoint_url")}
|
||||
placeholder="https://your-api-endpoint.com"
|
||||
variant={
|
||||
isSubmitting
|
||||
? "disabled"
|
||||
: endpointFieldError
|
||||
? "error"
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</FormField.Control>
|
||||
{!endpointFieldError && (
|
||||
<FormField.Description>
|
||||
Only connect to servers you trust. You are responsible for
|
||||
actions taken and data shared with this connection.
|
||||
</FormField.Description>
|
||||
)}
|
||||
<FormField.Message messages={{ error: endpointFieldError }} />
|
||||
</FormField>
|
||||
|
||||
<FormField
|
||||
className="w-full"
|
||||
state={apiKeyFieldError ? "error" : "idle"}
|
||||
>
|
||||
<FormField.Label>API Key</FormField.Label>
|
||||
<FormField.Control>
|
||||
<PasswordInputTypeIn
|
||||
value={form.api_key}
|
||||
onChange={(e) => {
|
||||
set("api_key", e.target.value);
|
||||
if (apiKeyServerError) setApiKeyServerError(false);
|
||||
if (isEdit) {
|
||||
setApiKeyCleared(
|
||||
e.target.value === "" && !!hook?.api_key_masked
|
||||
);
|
||||
}
|
||||
}}
|
||||
onBlur={() => touch("api_key")}
|
||||
placeholder={
|
||||
isEdit
|
||||
? hook?.api_key_masked ?? "Leave blank to keep current key"
|
||||
: undefined
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
error={!!apiKeyFieldError}
|
||||
/>
|
||||
</FormField.Control>
|
||||
{!apiKeyFieldError && (
|
||||
<FormField.Description>
|
||||
Onyx will use this key to authenticate with your API endpoint.
|
||||
</FormField.Description>
|
||||
)}
|
||||
<FormField.Message messages={{ error: apiKeyFieldError }} />
|
||||
</FormField>
|
||||
|
||||
{!isEdit && (isSubmitting || isConnected) && (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
alignItems="center"
|
||||
justifyContent="start"
|
||||
height="fit"
|
||||
gap={1}
|
||||
className="px-0.5"
|
||||
>
|
||||
<div className="p-0.5 shrink-0">
|
||||
{isConnected ? (
|
||||
<SvgCheckCircle
|
||||
size={16}
|
||||
className="text-status-success-05"
|
||||
/>
|
||||
) : (
|
||||
<SvgLoader size={16} className="animate-spin text-text-03" />
|
||||
)}
|
||||
</div>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{isConnected ? "Connection valid." : "Verifying connection…"}
|
||||
</Text>
|
||||
</Section>
|
||||
)}
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
cancel={
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => handleOpenChange(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
submit={
|
||||
<Disabled disabled={isSubmitting || !isValid || !hasChanges}>
|
||||
<Button
|
||||
onClick={handleSubmit}
|
||||
icon={
|
||||
isSubmitting && !isEdit
|
||||
? () => <SvgLoader size={16} className="animate-spin" />
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{isEdit ? "Save Changes" : "Connect"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -29,6 +29,14 @@ export interface HookResponse {
|
||||
updated_at: string;
|
||||
}
|
||||
|
||||
export interface HookFormState {
|
||||
name: string;
|
||||
endpoint_url: string;
|
||||
api_key: string;
|
||||
fail_strategy: HookFailStrategy;
|
||||
timeout_seconds: string;
|
||||
}
|
||||
|
||||
export interface HookCreateRequest {
|
||||
name: string;
|
||||
hook_point: HookPoint;
|
||||
|
||||
@@ -5,15 +5,27 @@ import {
|
||||
HookValidateResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
async function parseErrorDetail(
|
||||
res: Response,
|
||||
fallback: string
|
||||
): Promise<string> {
|
||||
export class HookAuthError extends Error {}
|
||||
export class HookTimeoutError extends Error {}
|
||||
export class HookConnectError extends Error {}
|
||||
|
||||
async function parseError(res: Response, fallback: string): Promise<Error> {
|
||||
try {
|
||||
const body = await res.json();
|
||||
return body?.detail ?? fallback;
|
||||
if (body?.error_code === "CREDENTIAL_INVALID") {
|
||||
return new HookAuthError(body?.detail ?? "Invalid API key.");
|
||||
}
|
||||
if (body?.error_code === "GATEWAY_TIMEOUT") {
|
||||
return new HookTimeoutError(body?.detail ?? "Connection timed out.");
|
||||
}
|
||||
if (body?.error_code === "BAD_GATEWAY") {
|
||||
return new HookConnectError(
|
||||
body?.detail ?? "Could not connect to endpoint."
|
||||
);
|
||||
}
|
||||
return new Error(body?.detail ?? fallback);
|
||||
} catch {
|
||||
return fallback;
|
||||
return new Error(fallback);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +38,7 @@ export async function createHook(
|
||||
body: JSON.stringify(req),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to create hook"));
|
||||
throw await parseError(res, "Failed to create hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -41,7 +53,7 @@ export async function updateHook(
|
||||
body: JSON.stringify(req),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to update hook"));
|
||||
throw await parseError(res, "Failed to update hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -49,7 +61,7 @@ export async function updateHook(
|
||||
export async function deleteHook(id: number): Promise<void> {
|
||||
const res = await fetch(`/api/admin/hooks/${id}`, { method: "DELETE" });
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to delete hook"));
|
||||
throw await parseError(res, "Failed to delete hook");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +70,7 @@ export async function activateHook(id: number): Promise<HookResponse> {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to activate hook"));
|
||||
throw await parseError(res, "Failed to activate hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -68,7 +80,7 @@ export async function deactivateHook(id: number): Promise<HookResponse> {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to deactivate hook"));
|
||||
throw await parseError(res, "Failed to deactivate hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -78,7 +90,7 @@ export async function validateHook(id: number): Promise<HookValidateResponse> {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to validate hook"));
|
||||
throw await parseError(res, "Failed to validate hook");
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
@@ -65,6 +66,7 @@ const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
"ollama_chat",
|
||||
"openrouter",
|
||||
"lm_studio",
|
||||
"bifrost",
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
@@ -138,6 +140,13 @@ const PROVIDER_MODAL_MAP: Record<
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
bifrost: (d, open, onOpenChange) => (
|
||||
<BifrostModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import {
|
||||
AzureIcon,
|
||||
ElevenLabsIcon,
|
||||
IconProps,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
@@ -19,12 +18,19 @@ import {
|
||||
import {
|
||||
activateVoiceProvider,
|
||||
deactivateVoiceProvider,
|
||||
deleteVoiceProvider,
|
||||
} from "@/lib/admin/voice/svc";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgMicrophone } from "@opal/icons";
|
||||
import { SvgMicrophone, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import VoiceProviderSetupModal from "@/app/admin/configuration/voice/VoiceProviderSetupModal";
|
||||
|
||||
interface ModelDetails {
|
||||
@@ -129,10 +135,153 @@ function getProviderIcon(
|
||||
|
||||
type ProviderMode = "stt" | "tts";
|
||||
|
||||
function getProviderLabel(providerType: string): string {
|
||||
switch (providerType) {
|
||||
case "openai":
|
||||
return "OpenAI";
|
||||
case "azure":
|
||||
return "Azure";
|
||||
case "elevenlabs":
|
||||
return "ElevenLabs";
|
||||
default:
|
||||
return providerType;
|
||||
}
|
||||
}
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
const route = ADMIN_ROUTES.VOICE;
|
||||
const pageDescription =
|
||||
"Configure speech-to-text and text-to-speech providers for voice input and spoken responses.";
|
||||
|
||||
interface VoiceDisconnectModalProps {
|
||||
disconnectTarget: {
|
||||
providerId: number;
|
||||
providerLabel: string;
|
||||
providerType: string;
|
||||
};
|
||||
providers: VoiceProviderView[];
|
||||
replacementProviderId: string | null;
|
||||
onReplacementChange: (id: string | null) => void;
|
||||
onClose: () => void;
|
||||
onDisconnect: () => void;
|
||||
}
|
||||
|
||||
function VoiceDisconnectModal({
|
||||
disconnectTarget,
|
||||
providers,
|
||||
replacementProviderId,
|
||||
onReplacementChange,
|
||||
onClose,
|
||||
onDisconnect,
|
||||
}: VoiceDisconnectModalProps) {
|
||||
const targetProvider = providers.find(
|
||||
(p) => p.id === disconnectTarget.providerId
|
||||
);
|
||||
const isActive =
|
||||
(targetProvider?.is_default_stt ?? false) ||
|
||||
(targetProvider?.is_default_tts ?? false);
|
||||
|
||||
// Find other configured providers that could serve as replacements
|
||||
const replacementOptions = providers.filter(
|
||||
(p) => p.id !== disconnectTarget.providerId && p.has_api_key
|
||||
);
|
||||
|
||||
const needsReplacement = isActive;
|
||||
const hasReplacements = replacementOptions.length > 0;
|
||||
|
||||
// Auto-select first replacement when modal opens
|
||||
useEffect(() => {
|
||||
if (needsReplacement && hasReplacements && !replacementProviderId) {
|
||||
const first = replacementOptions[0];
|
||||
if (first) onReplacementChange(String(first.id));
|
||||
}
|
||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
return (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgUnplug}
|
||||
title={`Disconnect ${disconnectTarget.providerLabel}`}
|
||||
description="Voice models"
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
needsReplacement && hasReplacements && !replacementProviderId
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be used for speech-to-text or text-to-speech, and it will no longer be your default. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
value={replacementProviderId ?? undefined}
|
||||
onValueChange={(v) => onReplacementChange(v)}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a replacement provider" />
|
||||
<InputSelect.Content>
|
||||
{replacementOptions.map((p) => (
|
||||
<InputSelect.Item
|
||||
key={p.id}
|
||||
value={String(p.id)}
|
||||
icon={getProviderIcon(p.provider_type)}
|
||||
>
|
||||
{getProviderLabel(p.provider_type)}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
<InputSelect.Separator />
|
||||
<InputSelect.Item value={NO_DEFAULT_VALUE} icon={SvgSlash}>
|
||||
<span>
|
||||
<b>No Default</b>
|
||||
<span className="text-text-03"> (Disable Voice)</span>
|
||||
</span>
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Section>
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be used for speech-to-text or text-to-speech, and it will no longer be your default.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using voice.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be available for voice.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
);
|
||||
}
|
||||
|
||||
export default function VoiceConfigurationPage() {
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
|
||||
@@ -146,6 +295,14 @@ export default function VoiceConfigurationPage() {
|
||||
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [disconnectTarget, setDisconnectTarget] = useState<{
|
||||
providerId: number;
|
||||
providerLabel: string;
|
||||
providerType: string;
|
||||
} | null>(null);
|
||||
const [replacementProviderId, setReplacementProviderId] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
|
||||
const { providers, error, isLoading, refresh: mutate } = useVoiceProviders();
|
||||
|
||||
@@ -237,6 +394,65 @@ export default function VoiceConfigurationPage() {
|
||||
handleModalClose();
|
||||
};
|
||||
|
||||
const handleDisconnect = async () => {
|
||||
if (!disconnectTarget) return;
|
||||
try {
|
||||
const targetProvider = providers.find(
|
||||
(p) => p.id === disconnectTarget.providerId
|
||||
);
|
||||
|
||||
// If a replacement was selected (not "No Default"), activate it for each
|
||||
// mode the disconnected provider was default for
|
||||
if (replacementProviderId && replacementProviderId !== NO_DEFAULT_VALUE) {
|
||||
const repId = Number(replacementProviderId);
|
||||
|
||||
if (targetProvider?.is_default_stt) {
|
||||
const resp = await activateVoiceProvider(repId, "stt");
|
||||
if (!resp.ok) {
|
||||
const errorBody = await resp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement STT provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (targetProvider?.is_default_tts) {
|
||||
const resp = await activateVoiceProvider(repId, "tts");
|
||||
if (!resp.ok) {
|
||||
const errorBody = await resp.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to activate replacement TTS provider."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const response = await deleteVoiceProvider(disconnectTarget.providerId);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: "Failed to disconnect provider."
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
toast.success(`${disconnectTarget.providerLabel} disconnected`);
|
||||
} catch (err) {
|
||||
console.error("Failed to disconnect voice provider:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Unexpected error occurred."
|
||||
);
|
||||
} finally {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
|
||||
return !!provider?.has_api_key;
|
||||
};
|
||||
@@ -289,6 +505,16 @@ export default function VoiceConfigurationPage() {
|
||||
onEdit={() => {
|
||||
if (provider) handleEdit(provider, mode, model.id);
|
||||
}}
|
||||
onDisconnect={
|
||||
status !== "disconnected" && provider
|
||||
? () =>
|
||||
setDisconnectTarget({
|
||||
providerId: provider.id,
|
||||
providerLabel: getProviderLabel(model.providerType),
|
||||
providerType: model.providerType,
|
||||
})
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -311,7 +537,7 @@ export default function VoiceConfigurationPage() {
|
||||
<Callout type="danger" title="Failed to load voice settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" mainContentBody text03>
|
||||
<Text as="p" font="main-content-body" color="text-03">
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
@@ -401,7 +627,7 @@ export default function VoiceConfigurationPage() {
|
||||
|
||||
{TTS_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.providerType} className="flex flex-col gap-2">
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.providerLabel}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
@@ -412,6 +638,20 @@ export default function VoiceConfigurationPage() {
|
||||
</div>
|
||||
</SettingsLayouts.Body>
|
||||
|
||||
{disconnectTarget && (
|
||||
<VoiceDisconnectModal
|
||||
disconnectTarget={disconnectTarget}
|
||||
providers={providers}
|
||||
replacementProviderId={replacementProviderId}
|
||||
onReplacementChange={setReplacementProviderId}
|
||||
onClose={() => {
|
||||
setDisconnectTarget(null);
|
||||
setReplacementProviderId(null);
|
||||
}}
|
||||
onDisconnect={() => void handleDisconnect()}
|
||||
/>
|
||||
)}
|
||||
|
||||
{modalOpen && selectedProvider && (
|
||||
<VoiceProviderSetupModal
|
||||
providerType={selectedProvider}
|
||||
|
||||
@@ -122,7 +122,10 @@ function MicrophoneButton({
|
||||
startRecording,
|
||||
stopRecording,
|
||||
setMuted,
|
||||
} = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });
|
||||
} = useVoiceRecorder({
|
||||
onFinalTranscript: handleFinalTranscript,
|
||||
autoStopOnSilence: autoSend,
|
||||
});
|
||||
|
||||
// Expose stopRecording to parent
|
||||
useEffect(() => {
|
||||
|
||||
278
web/src/sections/modals/llmConfig/BifrostModal.tsx
Normal file
278
web/src/sections/modals/llmConfig/BifrostModal.tsx
Normal file
@@ -0,0 +1,278 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
|
||||
const DEFAULT_API_BASE = "";
|
||||
|
||||
interface BifrostModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface BifrostModalInternalsProps {
|
||||
formikProps: FormikProps<BifrostModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BifrostModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BifrostModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useEffect(() => {
|
||||
if (existingLlmProvider && !isFetchDisabled) {
|
||||
handleFetchModels().catch((err) => {
|
||||
console.error("Failed to fetch Bifrost models:", err);
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.BIFROST}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
title="API Key"
|
||||
optional={true}
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
export default function BifrostModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BIFROST_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: BifrostModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BIFROST_PROVIDER_NAME,
|
||||
provider: BIFROST_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as BifrostModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BifrostModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
|
||||
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
@@ -56,6 +57,8 @@ export function getModalForExistingProvider(
|
||||
return <LMStudioForm {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
|
||||
246
web/tests/e2e/admin/image-generation/disconnect-provider.spec.ts
Normal file
246
web/tests/e2e/admin/image-generation/disconnect-provider.spec.ts
Normal file
@@ -0,0 +1,246 @@
|
||||
import { test, expect, Page, Locator } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const IMAGE_GENERATION_URL = "/admin/configuration/image-generation";
|
||||
|
||||
const FAKE_CONNECTED_CONFIG = {
|
||||
image_provider_id: "openai_dalle_3",
|
||||
model_configuration_id: 100,
|
||||
model_name: "dall-e-3",
|
||||
llm_provider_id: 100,
|
||||
llm_provider_name: "openai-dalle3",
|
||||
is_default: false,
|
||||
};
|
||||
|
||||
const FAKE_DEFAULT_CONFIG = {
|
||||
image_provider_id: "openai_gpt_image_1",
|
||||
model_configuration_id: 101,
|
||||
model_name: "gpt-image-1",
|
||||
llm_provider_id: 101,
|
||||
llm_provider_name: "openai-gpt-image-1",
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
function getProviderCard(page: Page, providerId: string): Locator {
|
||||
return page.getByLabel(`image-gen-provider-${providerId}`, { exact: true });
|
||||
}
|
||||
|
||||
function mainContainer(page: Page): Locator {
|
||||
return page.locator("[data-main-container]");
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up route mocks so the page sees configured providers
|
||||
* without needing real API keys.
|
||||
*/
|
||||
async function mockImageGenApis(
|
||||
page: Page,
|
||||
configs: (typeof FAKE_CONNECTED_CONFIG)[]
|
||||
) {
|
||||
await page.route("**/api/admin/image-generation/config", async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: configs });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
|
||||
await page.route(
|
||||
"**/api/admin/llm/provider?include_image_gen=true",
|
||||
async (route) => {
|
||||
await route.fulfill({ status: 200, json: { providers: [] } });
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
test.describe("Image Generation Provider Disconnect", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test("should disconnect a connected (non-default) provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
const configs = [{ ...FAKE_CONNECTED_CONFIG }, { ...FAKE_DEFAULT_CONFIG }];
|
||||
await mockImageGenApis(page, configs);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
const card = getProviderCard(page, "openai_dalle_3");
|
||||
await card.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "image-gen-disconnect-non-default-before",
|
||||
});
|
||||
|
||||
// Verify disconnect button exists and is enabled
|
||||
const disconnectButton = card.getByRole("button", {
|
||||
name: "Disconnect DALL-E 3",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock the DELETE to succeed and update the config list
|
||||
await page.route(
|
||||
"**/api/admin/image-generation/config/openai_dalle_3",
|
||||
async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
// Update the GET mock to return only the default config
|
||||
await page.unroute("**/api/admin/image-generation/config");
|
||||
await page.route(
|
||||
"**/api/admin/image-generation/config",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_DEFAULT_CONFIG }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Click disconnect
|
||||
await disconnectButton.click();
|
||||
|
||||
// Verify confirmation modal appears
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect DALL-E 3");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "image-gen-disconnect-non-default-modal",
|
||||
});
|
||||
|
||||
// Click Disconnect in the confirmation modal
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
// Verify the card reverts to disconnected state (shows "Connect" button)
|
||||
await expect(card.getByRole("button", { name: "Connect" })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "image-gen-disconnect-non-default-after",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting default provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
const configs = [{ ...FAKE_CONNECTED_CONFIG }, { ...FAKE_DEFAULT_CONFIG }];
|
||||
await mockImageGenApis(page, configs);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
const defaultCard = getProviderCard(page, "openai_gpt_image_1");
|
||||
await defaultCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
// The disconnect button should be visible and enabled
|
||||
const disconnectButton = defaultCard.getByRole("button", {
|
||||
name: "Disconnect GPT Image 1",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show replacement dropdown since there's an alternative
|
||||
await expect(
|
||||
confirmDialog.getByText("Session history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "image-gen-disconnect-default-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show connect message when disconnecting default provider with no alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Only the default config — no other providers configured
|
||||
await mockImageGenApis(page, [{ ...FAKE_DEFAULT_CONFIG }]);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
const defaultCard = getProviderCard(page, "openai_gpt_image_1");
|
||||
await defaultCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = defaultCard.getByRole("button", {
|
||||
name: "Disconnect GPT Image 1",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show message about connecting another provider
|
||||
await expect(
|
||||
confirmDialog.getByText("Connect another provider")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "image-gen-disconnect-no-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect button for unconfigured providers", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockImageGenApis(page, [{ ...FAKE_DEFAULT_CONFIG }]);
|
||||
|
||||
await page.goto(IMAGE_GENERATION_URL);
|
||||
await page.waitForSelector("text=Image Generation Model", {
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
// DALL-E 3 is not configured — should not have a disconnect button
|
||||
const card = getProviderCard(page, "openai_dalle_3");
|
||||
await card.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = card.getByRole("button", {
|
||||
name: "Disconnect DALL-E 3",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "image-gen-disconnect-unconfigured",
|
||||
});
|
||||
});
|
||||
});
|
||||
317
web/tests/e2e/admin/voice/disconnect-provider.spec.ts
Normal file
317
web/tests/e2e/admin/voice/disconnect-provider.spec.ts
Normal file
@@ -0,0 +1,317 @@
|
||||
import { test, expect, Page, Locator } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const VOICE_URL = "/admin/configuration/voice";
|
||||
|
||||
const FAKE_PROVIDERS = {
|
||||
openai_active_stt: {
|
||||
id: 1,
|
||||
name: "openai",
|
||||
provider_type: "openai",
|
||||
is_default_stt: true,
|
||||
is_default_tts: false,
|
||||
stt_model: "whisper",
|
||||
tts_model: null,
|
||||
default_voice: null,
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
openai_active_both: {
|
||||
id: 1,
|
||||
name: "openai",
|
||||
provider_type: "openai",
|
||||
is_default_stt: true,
|
||||
is_default_tts: true,
|
||||
stt_model: "whisper",
|
||||
tts_model: "tts-1",
|
||||
default_voice: "alloy",
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
openai_connected: {
|
||||
id: 1,
|
||||
name: "openai",
|
||||
provider_type: "openai",
|
||||
is_default_stt: false,
|
||||
is_default_tts: false,
|
||||
stt_model: null,
|
||||
tts_model: null,
|
||||
default_voice: null,
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
elevenlabs_connected: {
|
||||
id: 2,
|
||||
name: "elevenlabs",
|
||||
provider_type: "elevenlabs",
|
||||
is_default_stt: false,
|
||||
is_default_tts: false,
|
||||
stt_model: null,
|
||||
tts_model: null,
|
||||
default_voice: null,
|
||||
has_api_key: true,
|
||||
target_uri: null,
|
||||
},
|
||||
};
|
||||
|
||||
function findModelCard(page: Page, ariaLabel: string): Locator {
|
||||
return page.getByLabel(ariaLabel, { exact: true });
|
||||
}
|
||||
|
||||
function mainContainer(page: Page): Locator {
|
||||
return page.locator("[data-main-container]");
|
||||
}
|
||||
|
||||
async function mockVoiceApis(
|
||||
page: Page,
|
||||
providers: (typeof FAKE_PROVIDERS)[keyof typeof FAKE_PROVIDERS][]
|
||||
) {
|
||||
await page.route("**/api/admin/voice/providers", async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: providers });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
test.describe("Voice Provider Disconnect", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test("should disconnect a non-active provider and affect both STT and TTS cards", async ({
|
||||
page,
|
||||
}) => {
|
||||
const providers = [
|
||||
{ ...FAKE_PROVIDERS.openai_connected },
|
||||
{ ...FAKE_PROVIDERS.elevenlabs_connected },
|
||||
];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-non-active-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock DELETE to succeed and remove OpenAI from provider list
|
||||
await page.route("**/api/admin/voice/providers/1", async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
await page.unroute("**/api/admin/voice/providers");
|
||||
await page.route("**/api/admin/voice/providers", async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_PROVIDERS.elevenlabs_connected }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
});
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
// Modal shows provider name, not model name
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-non-active-modal",
|
||||
});
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
// Both STT and TTS cards for OpenAI revert to disconnected
|
||||
await expect(
|
||||
whisperCard.getByRole("button", { name: "Connect" })
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
|
||||
const tts1Card = findModelCard(page, "voice-tts-tts-1");
|
||||
await expect(tts1Card.getByRole("button", { name: "Connect" })).toBeVisible(
|
||||
{ timeout: 10000 }
|
||||
);
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-non-active-after",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting active provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// OpenAI is active for STT, ElevenLabs is also configured
|
||||
const providers = [
|
||||
{ ...FAKE_PROVIDERS.openai_active_stt },
|
||||
{ ...FAKE_PROVIDERS.elevenlabs_connected },
|
||||
];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-active-with-alt-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
// Should show replacement text and dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Session history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-active-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement when provider is default for both STT and TTS", async ({
|
||||
page,
|
||||
}) => {
|
||||
// OpenAI is default for both modes, ElevenLabs also configured
|
||||
const providers = [
|
||||
{ ...FAKE_PROVIDERS.openai_active_both },
|
||||
{ ...FAKE_PROVIDERS.elevenlabs_connected },
|
||||
];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-both-modes-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
// Should mention both modes
|
||||
await expect(
|
||||
confirmDialog.getByText("speech-to-text or text-to-speech")
|
||||
).toBeVisible();
|
||||
|
||||
// Should show replacement dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Session history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-both-modes-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show connect message when disconnecting active provider with no alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Only OpenAI configured, active for STT — no other providers
|
||||
const providers = [{ ...FAKE_PROVIDERS.openai_active_stt }];
|
||||
await mockVoiceApis(page, providers);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-no-alt-before",
|
||||
});
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect OpenAI");
|
||||
|
||||
// Should show message about connecting another provider
|
||||
await expect(
|
||||
confirmDialog.getByText("Connect another provider")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "voice-disconnect-no-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect button for unconfigured provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockVoiceApis(page, []);
|
||||
|
||||
await page.goto(VOICE_URL);
|
||||
await page.waitForSelector("text=Speech to Text", { timeout: 20000 });
|
||||
|
||||
const whisperCard = findModelCard(page, "voice-stt-whisper");
|
||||
await whisperCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = whisperCard.getByRole("button", {
|
||||
name: "Disconnect Whisper",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "voice-disconnect-unconfigured",
|
||||
});
|
||||
});
|
||||
});
|
||||
394
web/tests/e2e/admin/web-search/disconnect-provider.spec.ts
Normal file
394
web/tests/e2e/admin/web-search/disconnect-provider.spec.ts
Normal file
@@ -0,0 +1,394 @@
|
||||
import { test, expect, Page, Locator } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
const WEB_SEARCH_URL = "/admin/configuration/web-search";
|
||||
|
||||
const FAKE_SEARCH_PROVIDERS = {
|
||||
exa: {
|
||||
id: 1,
|
||||
name: "Exa",
|
||||
provider_type: "exa",
|
||||
is_active: true,
|
||||
config: null,
|
||||
has_api_key: true,
|
||||
},
|
||||
brave: {
|
||||
id: 2,
|
||||
name: "Brave",
|
||||
provider_type: "brave",
|
||||
is_active: false,
|
||||
config: null,
|
||||
has_api_key: true,
|
||||
},
|
||||
};
|
||||
|
||||
const FAKE_CONTENT_PROVIDERS = {
|
||||
firecrawl: {
|
||||
id: 10,
|
||||
name: "Firecrawl",
|
||||
provider_type: "firecrawl",
|
||||
is_active: true,
|
||||
config: { base_url: "https://api.firecrawl.dev/v2/scrape" },
|
||||
has_api_key: true,
|
||||
},
|
||||
exa: {
|
||||
id: 11,
|
||||
name: "Exa",
|
||||
provider_type: "exa",
|
||||
is_active: false,
|
||||
config: null,
|
||||
has_api_key: true,
|
||||
},
|
||||
};
|
||||
|
||||
function findProviderCard(page: Page, providerLabel: string): Locator {
|
||||
return page
|
||||
.locator("div.rounded-16")
|
||||
.filter({ hasText: providerLabel })
|
||||
.first();
|
||||
}
|
||||
|
||||
function mainContainer(page: Page): Locator {
|
||||
return page.locator("[data-main-container]");
|
||||
}
|
||||
|
||||
async function mockWebSearchApis(
|
||||
page: Page,
|
||||
searchProviders: (typeof FAKE_SEARCH_PROVIDERS)[keyof typeof FAKE_SEARCH_PROVIDERS][],
|
||||
contentProviders: (typeof FAKE_CONTENT_PROVIDERS)[keyof typeof FAKE_CONTENT_PROVIDERS][]
|
||||
) {
|
||||
await page.route(
|
||||
"**/api/admin/web-search/search-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: searchProviders });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await page.route(
|
||||
"**/api/admin/web-search/content-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({ status: 200, json: contentProviders });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
test.describe("Web Search Provider Disconnect", () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
});
|
||||
|
||||
test.describe("Search Engine Providers", () => {
|
||||
test("should disconnect a connected (non-active) search provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
const searchProviders = [
|
||||
{ ...FAKE_SEARCH_PROVIDERS.exa },
|
||||
{ ...FAKE_SEARCH_PROVIDERS.brave },
|
||||
];
|
||||
await mockWebSearchApis(page, searchProviders, []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const braveCard = findProviderCard(page, "Brave");
|
||||
await braveCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "web-search-disconnect-non-active-before",
|
||||
});
|
||||
|
||||
const disconnectButton = braveCard.getByRole("button", {
|
||||
name: "Disconnect Brave",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock the DELETE to succeed
|
||||
await page.route(
|
||||
"**/api/admin/web-search/search-providers/2",
|
||||
async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
await page.unroute("**/api/admin/web-search/search-providers");
|
||||
await page.route(
|
||||
"**/api/admin/web-search/search-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_SEARCH_PROVIDERS.exa }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect Brave");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-non-active-modal",
|
||||
});
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
await expect(
|
||||
braveCard.getByRole("button", { name: "Connect" })
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "web-search-disconnect-non-active-after",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting active search provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Exa is active, Brave is also configured
|
||||
const searchProviders = [
|
||||
{ ...FAKE_SEARCH_PROVIDERS.exa },
|
||||
{ ...FAKE_SEARCH_PROVIDERS.brave },
|
||||
];
|
||||
await mockWebSearchApis(page, searchProviders, []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const exaCard = findProviderCard(page, "Exa");
|
||||
await exaCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = exaCard.getByRole("button", {
|
||||
name: "Disconnect Exa",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect Exa");
|
||||
|
||||
// Should show replacement dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Search history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-active-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should show connect message when disconnecting active search provider with no alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Only Exa configured and active
|
||||
await mockWebSearchApis(page, [{ ...FAKE_SEARCH_PROVIDERS.exa }], []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const exaCard = findProviderCard(page, "Exa");
|
||||
await exaCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = exaCard.getByRole("button", {
|
||||
name: "Disconnect Exa",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show message about connecting another provider
|
||||
await expect(
|
||||
confirmDialog.getByText("Connect another provider")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect button should be enabled
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-no-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect button for unconfigured search provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockWebSearchApis(page, [{ ...FAKE_SEARCH_PROVIDERS.exa }], []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Search Engine", { timeout: 20000 });
|
||||
|
||||
const braveCard = findProviderCard(page, "Brave");
|
||||
await braveCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = braveCard.getByRole("button", {
|
||||
name: "Disconnect Brave",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
|
||||
await expectElementScreenshot(mainContainer(page), {
|
||||
name: "web-search-disconnect-unconfigured",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Web Crawler (Content) Providers", () => {
|
||||
test("should disconnect a connected (non-active) content provider", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Firecrawl connected but not active, Exa is active
|
||||
const contentProviders = [
|
||||
{ ...FAKE_CONTENT_PROVIDERS.firecrawl, is_active: false },
|
||||
{ ...FAKE_CONTENT_PROVIDERS.exa, is_active: true },
|
||||
];
|
||||
await mockWebSearchApis(page, [], contentProviders);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
|
||||
|
||||
const firecrawlCard = findProviderCard(page, "Firecrawl");
|
||||
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = firecrawlCard.getByRole("button", {
|
||||
name: "Disconnect Firecrawl",
|
||||
});
|
||||
await expect(disconnectButton).toBeVisible();
|
||||
await expect(disconnectButton).toBeEnabled();
|
||||
|
||||
// Mock the DELETE to succeed
|
||||
await page.route(
|
||||
"**/api/admin/web-search/content-providers/10",
|
||||
async (route) => {
|
||||
if (route.request().method() === "DELETE") {
|
||||
await page.unroute("**/api/admin/web-search/content-providers");
|
||||
await page.route(
|
||||
"**/api/admin/web-search/content-providers",
|
||||
async (route) => {
|
||||
if (route.request().method() === "GET") {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
json: [{ ...FAKE_CONTENT_PROVIDERS.exa, is_active: true }],
|
||||
});
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
await route.fulfill({ status: 200, json: {} });
|
||||
} else {
|
||||
await route.continue();
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
await expect(confirmDialog).toContainText("Disconnect Firecrawl");
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-content-non-active-modal",
|
||||
});
|
||||
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await confirmButton.click();
|
||||
|
||||
await expect(
|
||||
firecrawlCard.getByRole("button", { name: "Connect" })
|
||||
).toBeVisible({ timeout: 10000 });
|
||||
});
|
||||
|
||||
test("should show replacement dropdown when disconnecting active content provider with alternatives", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Firecrawl is active, Exa is also configured
|
||||
const contentProviders = [
|
||||
{ ...FAKE_CONTENT_PROVIDERS.firecrawl },
|
||||
{ ...FAKE_CONTENT_PROVIDERS.exa },
|
||||
];
|
||||
await mockWebSearchApis(page, [], contentProviders);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
|
||||
|
||||
const firecrawlCard = findProviderCard(page, "Firecrawl");
|
||||
await firecrawlCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = firecrawlCard.getByRole("button", {
|
||||
name: "Disconnect Firecrawl",
|
||||
});
|
||||
await disconnectButton.click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog");
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 5000 });
|
||||
|
||||
// Should show replacement dropdown
|
||||
await expect(
|
||||
confirmDialog.getByText("Search history will be preserved")
|
||||
).toBeVisible();
|
||||
|
||||
// Disconnect should be enabled because first replacement is auto-selected
|
||||
const confirmButton = confirmDialog.getByRole("button", {
|
||||
name: "Disconnect",
|
||||
});
|
||||
await expect(confirmButton).toBeEnabled();
|
||||
|
||||
await expectElementScreenshot(confirmDialog, {
|
||||
name: "web-search-disconnect-content-active-with-alt-modal",
|
||||
});
|
||||
});
|
||||
|
||||
test("should not show disconnect for Onyx Web Crawler (built-in)", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockWebSearchApis(page, [], []);
|
||||
|
||||
await page.goto(WEB_SEARCH_URL);
|
||||
await page.waitForSelector("text=Web Crawler", { timeout: 20000 });
|
||||
|
||||
const onyxCard = findProviderCard(page, "Onyx Web Crawler");
|
||||
await onyxCard.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
const disconnectButton = onyxCard.getByRole("button", {
|
||||
name: "Disconnect Onyx Web Crawler",
|
||||
});
|
||||
await expect(disconnectButton).not.toBeVisible();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4,6 +4,14 @@ import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
test.use({ storageState: "admin_auth.json" });
|
||||
|
||||
/** Maps each settings slug to the header title shown on that page. */
|
||||
const SLUG_TO_HEADER: Record<string, string> = {
|
||||
general: "Profile",
|
||||
"chat-preferences": "Chats",
|
||||
"accounts-access": "Accounts",
|
||||
connectors: "Connectors",
|
||||
};
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Settings pages (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
@@ -11,21 +19,33 @@ for (const theme of THEMES) {
|
||||
});
|
||||
|
||||
test("should screenshot each settings tab", async ({ page }) => {
|
||||
await page.goto("/app/settings");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await page.goto("/app/settings/general");
|
||||
await page
|
||||
.getByTestId("settings-left-tab-navigation")
|
||||
.waitFor({ state: "visible" });
|
||||
|
||||
const nav = page.getByTestId("settings-left-tab-navigation");
|
||||
const tabs = nav.locator("a");
|
||||
await expect(tabs.first()).toBeVisible({ timeout: 10_000 });
|
||||
const count = await tabs.count();
|
||||
|
||||
expect(count).toBeGreaterThan(0);
|
||||
for (let i = 0; i < count; i++) {
|
||||
const tab = tabs.nth(i);
|
||||
const href = await tab.getAttribute("href");
|
||||
const slug = href ? href.replace("/app/settings/", "") : `tab-${i}`;
|
||||
|
||||
await tab.click();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const expectedHeader = SLUG_TO_HEADER[slug];
|
||||
if (expectedHeader) {
|
||||
await expect(
|
||||
page
|
||||
.locator(".opal-content-md-header")
|
||||
.filter({ hasText: expectedHeader })
|
||||
).toBeVisible({ timeout: 10_000 });
|
||||
} else {
|
||||
await page.waitForLoadState("networkidle");
|
||||
}
|
||||
|
||||
await expectScreenshot(page, {
|
||||
name: `settings-${theme}-${slug}`,
|
||||
|
||||
Reference in New Issue
Block a user