mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-09 00:42:47 +00:00
Compare commits
18 Commits
feat/resol
...
v3.1.3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d2bd97412 | ||
|
|
3d48b6a63e | ||
|
|
2a7b7c9187 | ||
|
|
c348d1855d | ||
|
|
b4579a1365 | ||
|
|
893c094aed | ||
|
|
f8a55712d2 | ||
|
|
591afd4fb1 | ||
|
|
9328070dc0 | ||
|
|
6163521126 | ||
|
|
d42c5616b0 | ||
|
|
aeb4fdd6c1 | ||
|
|
c673959714 | ||
|
|
cb36562802 | ||
|
|
efc424bf3e | ||
|
|
e0baaf85e5 | ||
|
|
a0ffd47e2c | ||
|
|
d0396a1337 |
@@ -28,6 +28,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -187,7 +188,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
@@ -227,6 +227,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_permission_sync_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
@@ -162,7 +163,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
@@ -221,6 +221,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_external_group_sync_fences(
|
||||
tenant_id, self.app, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -7,7 +8,59 @@ from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
_broker_client: Redis | None = None
|
||||
_broker_url: str | None = None
|
||||
_broker_client_lock = threading.Lock()
|
||||
|
||||
|
||||
def celery_get_broker_client(app: Celery) -> Redis:
|
||||
"""Return a shared Redis client connected to the Celery broker DB.
|
||||
|
||||
Uses a module-level singleton so all tasks on a worker share one
|
||||
connection instead of creating a new one per call. The client
|
||||
connects directly to the broker Redis DB (parsed from the broker URL).
|
||||
|
||||
Thread-safe via lock — safe for use in Celery thread-pool workers.
|
||||
|
||||
Usage:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
length = celery_get_queue_length(queue, r_celery)
|
||||
"""
|
||||
global _broker_client, _broker_url
|
||||
with _broker_client_lock:
|
||||
url = app.conf.broker_url
|
||||
if _broker_client is not None and _broker_url == url:
|
||||
try:
|
||||
_broker_client.ping()
|
||||
return _broker_client
|
||||
except Exception:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
elif _broker_client is not None:
|
||||
try:
|
||||
_broker_client.close()
|
||||
except Exception:
|
||||
pass
|
||||
_broker_client = None
|
||||
|
||||
_broker_url = url
|
||||
_broker_client = Redis.from_url(
|
||||
url,
|
||||
decode_responses=False,
|
||||
health_check_interval=REDIS_HEALTH_CHECK_INTERVAL,
|
||||
socket_keepalive=True,
|
||||
socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
return _broker_client
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
|
||||
@@ -14,6 +14,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
@@ -132,7 +133,6 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -149,6 +149,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -449,7 +450,7 @@ def check_indexing_completion(
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(task.app)
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
@@ -19,6 +18,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
@@ -698,31 +698,27 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
lambda: _collect_queue_metrics(redis_celery),
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
# Collect queue metrics with broker connection
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_metrics = _collect_queue_metrics(r_celery)
|
||||
|
||||
# Collect and log each metric
|
||||
# Collect remaining metrics (no broker connection needed)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
all_metrics: list[Metric] = queue_metrics
|
||||
all_metrics.extend(_collect_connector_metrics(db_session, redis_std))
|
||||
all_metrics.extend(_collect_sync_metrics(db_session, redis_std))
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
for metric in all_metrics:
|
||||
if metric.key is None or not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
|
||||
if metric.key is not None:
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -890,7 +886,7 @@ def monitor_celery_queues_helper(
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(task.app)
|
||||
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
|
||||
n_docfetching = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
@@ -1080,7 +1076,7 @@ def cloud_monitor_celery_pidbox(
|
||||
num_deleted = 0
|
||||
|
||||
MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
for key in r_celery.scan_iter("*.reply.celery.pidbox"):
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
@@ -203,7 +204,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
@@ -261,6 +261,7 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import build_access_for_user_files
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
@@ -105,7 +106,7 @@ def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
redis_celery = celery_get_broker_client(celery_app)
|
||||
return celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
|
||||
)
|
||||
@@ -238,7 +239,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -591,7 +592,7 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r_celery = celery_get_broker_client(self.app)
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
|
||||
@@ -12,6 +12,11 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
|
||||
@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
@@ -132,32 +131,47 @@ def get_chat_sessions_by_user(
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(ChatSession.project_id == project_id)
|
||||
elif only_non_project_chats:
|
||||
stmt = stmt.where(ChatSession.project_id.is_(None))
|
||||
|
||||
if not include_failed_chats:
|
||||
non_system_message_exists_subq = (
|
||||
exists()
|
||||
.where(ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
# When filtering out failed chats, we apply the limit in Python after
|
||||
# filtering rather than in SQL, since the post-filter may remove rows.
|
||||
if limit and include_failed_chats:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
chat_sessions = list(result.scalars().all())
|
||||
|
||||
return list(chat_sessions)
|
||||
if not include_failed_chats and chat_sessions:
|
||||
# Filter out "failed" sessions (those with only SYSTEM messages)
|
||||
# using a separate efficient query instead of a correlated EXISTS
|
||||
# subquery, which causes full sequential scans of chat_message.
|
||||
leeway = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway]
|
||||
|
||||
if session_ids:
|
||||
valid_session_ids_stmt = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.where(ChatMessage.chat_session_id.in_(session_ids))
|
||||
.where(ChatMessage.message_type != MessageType.SYSTEM)
|
||||
.distinct()
|
||||
)
|
||||
valid_session_ids = set(
|
||||
db_session.execute(valid_session_ids_stmt).scalars().all()
|
||||
)
|
||||
|
||||
chat_sessions = [
|
||||
cs
|
||||
for cs in chat_sessions
|
||||
if cs.time_created >= leeway or cs.id in valid_session_ids
|
||||
]
|
||||
|
||||
if limit:
|
||||
chat_sessions = chat_sessions[:limit]
|
||||
|
||||
return chat_sessions
|
||||
|
||||
|
||||
def delete_orphaned_search_docs(db_session: Session) -> None:
|
||||
|
||||
@@ -8,6 +8,8 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -45,6 +47,23 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
|
||||
return fetch_all_federated_connectors(db_session)
|
||||
|
||||
|
||||
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
|
||||
"""Raise if any credential string value contains mask placeholder characters.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
Both must be rejected.
|
||||
"""
|
||||
for key, val in credentials.items():
|
||||
if isinstance(val, str) and (
|
||||
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
|
||||
)
|
||||
|
||||
|
||||
def validate_federated_connector_credentials(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -66,6 +85,8 @@ def create_federated_connector(
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> FederatedConnector:
|
||||
"""Create a new federated connector with credential and config validation."""
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before creating
|
||||
if not validate_federated_connector_credentials(source, credentials):
|
||||
raise ValueError(
|
||||
@@ -277,6 +298,8 @@ def update_federated_connector(
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before updating
|
||||
if not validate_federated_connector_credentials(
|
||||
federated_connector.source, credentials
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -47,7 +46,6 @@ async def fetch_user_for_pat(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
@@ -229,7 +229,9 @@ def get_memories_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> Sequence[Memory]:
|
||||
return db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
return db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc())
|
||||
).all()
|
||||
|
||||
|
||||
def update_user_pinned_assistants(
|
||||
|
||||
@@ -49,9 +49,21 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
|
||||
def get_markitdown_converter() -> "MarkItDown":
|
||||
global _MARKITDOWN_CONVERTER
|
||||
from markitdown import MarkItDown
|
||||
|
||||
if _MARKITDOWN_CONVERTER is None:
|
||||
from markitdown import MarkItDown
|
||||
|
||||
# Patch this function to effectively no-op because we were seeing this
|
||||
# module take an inordinate amount of time to convert charts to markdown,
|
||||
# making some powerpoint files with many or complicated charts nearly
|
||||
# unindexable.
|
||||
from markitdown.converters._pptx_converter import PptxConverter
|
||||
|
||||
setattr(
|
||||
PptxConverter,
|
||||
"_convert_chart_to_markdown",
|
||||
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
|
||||
)
|
||||
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
|
||||
return _MARKITDOWN_CONVERTER
|
||||
|
||||
@@ -202,18 +214,26 @@ def read_pdf_file(
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
if pdf_reader.is_encrypted:
|
||||
# Try the explicit password first, then fall back to an empty
|
||||
# string. Owner-password-only PDFs (permission restrictions but
|
||||
# no open password) decrypt successfully with "".
|
||||
# See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
passwords = [p for p in [pdf_pass, ""] if p is not None]
|
||||
decrypt_success = False
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error("Unable to decrypt pdf")
|
||||
for pw in passwords:
|
||||
try:
|
||||
if pdf_reader.decrypt(pw) != 0:
|
||||
decrypt_success = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not decrypt_success:
|
||||
logger.error(
|
||||
"Encrypted PDF could not be decrypted, returning empty text."
|
||||
)
|
||||
return "", metadata, []
|
||||
elif pdf_reader.is_encrypted:
|
||||
logger.warning("No Password for an encrypted PDF, returning empty text.")
|
||||
return "", metadata, []
|
||||
|
||||
# Basic PDF metadata
|
||||
if pdf_reader.metadata is not None:
|
||||
|
||||
@@ -33,8 +33,20 @@ def is_pdf_protected(file: IO[Any]) -> bool:
|
||||
|
||||
with preserve_position(file):
|
||||
reader = PdfReader(file)
|
||||
if not reader.is_encrypted:
|
||||
return False
|
||||
|
||||
return bool(reader.is_encrypted)
|
||||
# PDFs with only an owner password (permission restrictions like
|
||||
# print/copy disabled) use an empty user password — any viewer can open
|
||||
# them without prompting. decrypt("") returns 0 only when a real user
|
||||
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
try:
|
||||
return reader.decrypt("") == 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to evaluate PDF encryption; treating as password protected"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def is_docx_protected(file: IO[Any]) -> bool:
|
||||
|
||||
@@ -26,6 +26,7 @@ class LlmProviderNames(str, Enum):
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -46,6 +47,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
]
|
||||
|
||||
|
||||
@@ -64,6 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -116,6 +119,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -185,6 +185,21 @@ def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool:
|
||||
"""Check if the prompt contains any assistant messages with tool_calls.
|
||||
|
||||
When Anthropic's extended thinking is enabled, the API requires every
|
||||
assistant message to start with a thinking block before any tool_use
|
||||
blocks. Since we don't preserve thinking_blocks (they carry
|
||||
cryptographic signatures that can't be reconstructed), we must skip
|
||||
the thinking param whenever history contains prior tool-calling turns.
|
||||
"""
|
||||
from onyx.llm.models import AssistantMessage
|
||||
|
||||
msgs = prompt if isinstance(prompt, list) else [prompt]
|
||||
return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs)
|
||||
|
||||
|
||||
def _is_vertex_model_rejecting_output_config(model_name: str) -> bool:
|
||||
normalized_model_name = model_name.lower()
|
||||
return any(
|
||||
@@ -290,12 +305,19 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
|
||||
# model names directly to the endpoint. We route through LiteLLM's
|
||||
# openai provider with the server's base URL, and ensure /v1 is appended.
|
||||
if model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
):
|
||||
self._custom_llm_provider = "openai"
|
||||
# LiteLLM's OpenAI client requires an api_key to be set.
|
||||
# Many OpenAI-compatible servers don't need auth, so supply a
|
||||
# placeholder to prevent LiteLLM from raising AuthenticationError.
|
||||
if not self._api_key:
|
||||
model_kwargs.setdefault("api_key", "not-needed")
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
@@ -412,17 +434,20 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
is_openai_compatible_proxy = self._model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
)
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
if is_openai_compatible_proxy:
|
||||
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
|
||||
# servers) expect model names sent directly to their endpoint.
|
||||
# We use custom_llm_provider="openai" so LiteLLM doesn't try
|
||||
# to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
@@ -466,7 +491,20 @@ class LitellmLLM(LLM):
|
||||
reasoning_effort
|
||||
)
|
||||
|
||||
if budget_tokens is not None:
|
||||
# Anthropic requires every assistant message with tool_use
|
||||
# blocks to start with a thinking block that carries a
|
||||
# cryptographic signature. We don't preserve those blocks
|
||||
# across turns, so skip thinking when the history already
|
||||
# contains tool-calling assistant messages. LiteLLM's
|
||||
# modify_params workaround doesn't cover all providers
|
||||
# (notably Bedrock).
|
||||
can_enable_thinking = (
|
||||
budget_tokens is not None
|
||||
and not _prompt_contains_tool_call_history(prompt)
|
||||
)
|
||||
|
||||
if can_enable_thinking:
|
||||
assert budget_tokens is not None # mypy
|
||||
if max_tokens is not None:
|
||||
# Anthropic has a weird rule where max token has to be at least as much as budget tokens if set
|
||||
# and the minimum budget tokens is 1024
|
||||
@@ -500,7 +538,10 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
if (
|
||||
not (is_claude_model or is_ollama or is_mistral)
|
||||
or is_openai_compatible_proxy
|
||||
):
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included unless the request is
|
||||
|
||||
@@ -15,6 +15,8 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
@@ -51,6 +52,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
|
||||
}
|
||||
|
||||
|
||||
@@ -336,6 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -6,6 +6,7 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -16,6 +17,7 @@ def main() -> None:
|
||||
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
|
||||
return
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
@@ -74,6 +74,8 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
@@ -1575,3 +1577,95 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/openai-compatible/available-models")
|
||||
def get_openai_compatible_server_available_models(
|
||||
request: OpenAICompatibleModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenAICompatibleFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
OpenAICompatibleFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse OpenAI-compatible model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenAI Compatible",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openai_compatible_server_response(
|
||||
api_base: str, api_key: str | None = None
|
||||
) -> dict:
|
||||
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="OpenAI Compatible",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -464,3 +464,18 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
# OpenAI Compatible dynamic models fetch
|
||||
class OpenAICompatibleModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class OpenAICompatibleFinalModelResponse(BaseModel):
|
||||
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
|
||||
display_name: str # Human-readable name from API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -26,6 +26,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ class UserInfo(BaseModel):
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None,
|
||||
memories: list[MemoryItem] | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -191,10 +192,7 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
enable_memory_tool=user.enable_memory_tool,
|
||||
memories=[
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in (user.memories or [])
|
||||
],
|
||||
memories=memories or [],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -57,6 +57,7 @@ from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.user_preferences import deactivate_user
|
||||
from onyx.db.user_preferences import get_all_user_assistant_specific_configs
|
||||
from onyx.db.user_preferences import get_latest_access_token_for_user
|
||||
from onyx.db.user_preferences import get_memories_for_user
|
||||
from onyx.db.user_preferences import update_assistant_preferences
|
||||
from onyx.db.user_preferences import update_user_assistant_visibility
|
||||
from onyx.db.user_preferences import update_user_auto_scroll
|
||||
@@ -823,6 +824,11 @@ def verify_user_logged_in(
|
||||
[],
|
||||
),
|
||||
)
|
||||
memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
@@ -833,6 +839,7 @@ def verify_user_logged_in(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
return user_info
|
||||
@@ -930,7 +937,8 @@ def update_user_personalization_api(
|
||||
else user.enable_memory_tool
|
||||
)
|
||||
existing_memories = [
|
||||
MemoryItem(id=memory.id, content=memory.memory_text) for memory in user.memories
|
||||
MemoryItem(id=memory.id, content=memory.memory_text)
|
||||
for memory in get_memories_for_user(user.id, db_session)
|
||||
]
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
|
||||
@@ -12,7 +12,6 @@ stale, which is fine for monitoring dashboards.
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -104,25 +103,23 @@ class _CachedCollector(Collector):
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
@@ -404,17 +401,19 @@ class RedisHealthCollector(_CachedCollector):
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
self._celery_app: Any | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app for broker Redis access."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
from onyx.background.celery.celery_redis import celery_get_broker_client
|
||||
|
||||
redis_client = celery_get_broker_client(self._celery_app)
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
|
||||
@@ -3,12 +3,8 @@
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
@@ -21,7 +17,7 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# state) until configure() / set_celery_app() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
@@ -32,72 +28,15 @@ _worker_health_collector = WorkerHealthCollector()
|
||||
_heartbeat_monitor: WorkerHeartbeatMonitor | None = None
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
celery_app: The Celery application instance. Used to obtain a
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_queue_collector.set_celery_app(celery_app)
|
||||
_redis_health_collector.set_celery_app(celery_app)
|
||||
|
||||
# Start the heartbeat monitor daemon thread — uses a single persistent
|
||||
# connection to receive worker-heartbeat events.
|
||||
|
||||
@@ -129,6 +129,10 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(_PATCH_QUEUE_DEPTH, return_value=0),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -88,10 +88,22 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
|
||||
Also patches ``celery_get_broker_client`` so the mock app doesn't need
|
||||
a real broker URL.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -90,8 +90,17 @@ def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, Non
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
with (
|
||||
patch.object(
|
||||
type(task_instance),
|
||||
"app",
|
||||
new_callable=PropertyMock,
|
||||
return_value=mock_app,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.db.federated import _reject_masked_credentials
|
||||
|
||||
|
||||
class TestRejectMaskedCredentials:
|
||||
"""Verify that masked credential values are never accepted for DB writes.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
_reject_masked_credentials must catch both.
|
||||
"""
|
||||
|
||||
def test_rejects_fully_masked_value(self) -> None:
|
||||
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": masked})
|
||||
|
||||
def test_rejects_long_string_masked_value(self) -> None:
|
||||
"""mask_string returns 'first4...last4' for long strings — the real
|
||||
format used for OAuth credentials like client_id and client_secret."""
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": "1234...7890"})
|
||||
|
||||
def test_rejects_when_any_field_is_masked(self) -> None:
|
||||
"""Even if client_id is real, a masked client_secret must be caught."""
|
||||
with pytest.raises(ValueError, match="client_secret"):
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": MASK_CREDENTIAL_CHAR * 12,
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_real_credentials(self) -> None:
|
||||
# Should not raise
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": "test_client_secret_value",
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_empty_dict(self) -> None:
|
||||
# Should not raise — empty credentials are handled elsewhere
|
||||
_reject_masked_credentials({})
|
||||
|
||||
def test_ignores_non_string_values(self) -> None:
|
||||
# Non-string values (None, bool, int) should pass through
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "real_value",
|
||||
"redirect_uri": None,
|
||||
"some_flag": True,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Tests for celery_get_broker_client singleton."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.celery import celery_redis
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_singleton() -> Iterator[None]:
|
||||
"""Reset the module-level singleton between tests."""
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
yield
|
||||
celery_redis._broker_client = None
|
||||
celery_redis._broker_url = None
|
||||
|
||||
|
||||
def _make_mock_app(broker_url: str = "redis://localhost:6379/15") -> MagicMock:
|
||||
app = MagicMock()
|
||||
app.conf.broker_url = broker_url
|
||||
return app
|
||||
|
||||
|
||||
class TestCeleryGetBrokerClient:
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_creates_client_on_first_call(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
result = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert result is mock_client
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://localhost:6379/15"
|
||||
assert call_args[1]["decode_responses"] is False
|
||||
assert call_args[1]["socket_keepalive"] is True
|
||||
assert call_args[1]["retry_on_timeout"] is True
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reuses_cached_client(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_redis_cls.from_url.return_value = mock_client
|
||||
|
||||
app = _make_mock_app()
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
|
||||
assert client1 is client2
|
||||
# from_url called only once
|
||||
assert mock_redis_cls.from_url.call_count == 1
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_reconnects_on_ping_failure(self, mock_redis_cls: MagicMock) -> None:
|
||||
stale_client = MagicMock()
|
||||
stale_client.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
fresh_client = MagicMock()
|
||||
fresh_client.ping.return_value = True
|
||||
|
||||
mock_redis_cls.from_url.side_effect = [stale_client, fresh_client]
|
||||
|
||||
app = _make_mock_app()
|
||||
|
||||
# First call creates stale_client
|
||||
client1 = celery_redis.celery_get_broker_client(app)
|
||||
assert client1 is stale_client
|
||||
|
||||
# Second call: ping fails, creates fresh_client
|
||||
client2 = celery_redis.celery_get_broker_client(app)
|
||||
assert client2 is fresh_client
|
||||
assert mock_redis_cls.from_url.call_count == 2
|
||||
|
||||
@patch("onyx.background.celery.celery_redis.Redis")
|
||||
def test_uses_broker_url_from_app_config(self, mock_redis_cls: MagicMock) -> None:
|
||||
mock_redis_cls.from_url.return_value = MagicMock()
|
||||
|
||||
app = _make_mock_app("redis://custom-host:6380/3")
|
||||
celery_redis.celery_get_broker_client(app)
|
||||
|
||||
call_args = mock_redis_cls.from_url.call_args
|
||||
assert call_args[0][0] == "redis://custom-host:6380/3"
|
||||
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
225
backend/tests/unit/onyx/db/test_chat_sessions.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Tests for get_chat_sessions_by_user filtering behavior.
|
||||
|
||||
Verifies that failed chat sessions (those with only SYSTEM messages) are
|
||||
correctly filtered out while preserving recently created sessions, matching
|
||||
the behavior specified in PR #7233.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
|
||||
def _make_session(
|
||||
user_id: UUID,
|
||||
time_created: datetime | None = None,
|
||||
time_updated: datetime | None = None,
|
||||
description: str = "",
|
||||
) -> MagicMock:
|
||||
"""Create a mock ChatSession with the given attributes."""
|
||||
session = MagicMock(spec=ChatSession)
|
||||
session.id = uuid4()
|
||||
session.user_id = user_id
|
||||
session.time_created = time_created or datetime.now(timezone.utc)
|
||||
session.time_updated = time_updated or session.time_created
|
||||
session.description = description
|
||||
session.deleted = False
|
||||
session.onyxbot_flow = False
|
||||
session.project_id = None
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id() -> UUID:
|
||||
return uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def old_time() -> datetime:
|
||||
"""A timestamp well outside the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recent_time() -> datetime:
|
||||
"""A timestamp within the 5-minute leeway window."""
|
||||
return datetime.now(timezone.utc) - timedelta(minutes=2)
|
||||
|
||||
|
||||
class TestGetChatSessionsByUser:
|
||||
"""Tests for the failed chat filtering logic in get_chat_sessions_by_user."""
|
||||
|
||||
def test_filters_out_failed_sessions(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Sessions with only SYSTEM messages should be excluded."""
|
||||
valid_session = _make_session(user_id, time_created=old_time)
|
||||
failed_session = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
# First execute: returns all sessions
|
||||
# Second execute: returns only the valid session's ID (has non-system msgs)
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
valid_session,
|
||||
failed_session,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [valid_session.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == valid_session.id
|
||||
|
||||
def test_keeps_recent_sessions_without_messages(
|
||||
self, user_id: UUID, recent_time: datetime
|
||||
) -> None:
|
||||
"""Recently created sessions should be kept even without messages."""
|
||||
recent_session = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [recent_session]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == recent_session.id
|
||||
# Should only have been called once — no second query needed
|
||||
# because the recent session is within the leeway window
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_include_failed_chats_skips_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""When include_failed_chats=True, no filtering should occur."""
|
||||
session_a = _make_session(user_id, time_created=old_time)
|
||||
session_b = _make_session(user_id, time_created=old_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [session_a, session_b]
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=True,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Only one DB call — no second query for message validation
|
||||
assert db_session.execute.call_count == 1
|
||||
|
||||
def test_limit_applied_after_filtering(
|
||||
self, user_id: UUID, old_time: datetime
|
||||
) -> None:
|
||||
"""Limit should be applied after filtering, not before."""
|
||||
sessions = [_make_session(user_id, time_created=old_time) for _ in range(5)]
|
||||
valid_ids = [s.id for s in sessions[:3]]
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = sessions
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = valid_ids
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
# Should be the first 2 valid sessions (order preserved)
|
||||
assert result[0].id == sessions[0].id
|
||||
assert result[1].id == sessions[1].id
|
||||
|
||||
def test_mixed_recent_and_old_sessions(
|
||||
self, user_id: UUID, old_time: datetime, recent_time: datetime
|
||||
) -> None:
|
||||
"""Mix of recent and old sessions should filter correctly."""
|
||||
old_valid = _make_session(user_id, time_created=old_time)
|
||||
old_failed = _make_session(user_id, time_created=old_time)
|
||||
recent_no_msgs = _make_session(user_id, time_created=recent_time)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result_1 = MagicMock()
|
||||
mock_result_1.scalars.return_value.all.return_value = [
|
||||
old_valid,
|
||||
old_failed,
|
||||
recent_no_msgs,
|
||||
]
|
||||
|
||||
mock_result_2 = MagicMock()
|
||||
mock_result_2.scalars.return_value.all.return_value = [old_valid.id]
|
||||
|
||||
db_session.execute.side_effect = [mock_result_1, mock_result_2]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
result_ids = {cs.id for cs in result}
|
||||
assert old_valid.id in result_ids
|
||||
assert recent_no_msgs.id in result_ids
|
||||
assert old_failed.id not in result_ids
|
||||
|
||||
def test_empty_result(self, user_id: UUID) -> None:
|
||||
"""No sessions should return empty list without errors."""
|
||||
db_session = MagicMock(spec=Session)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
|
||||
db_session.execute.side_effect = [mock_result]
|
||||
|
||||
result = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
include_failed_chats=False,
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert db_session.execute.call_count == 1
|
||||
@@ -0,0 +1,76 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer <1083d595b1>
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 42
|
||||
>>
|
||||
stream
|
||||
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD><0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
|
||||
endstream
|
||||
endobj
|
||||
6 0 obj
|
||||
<<
|
||||
/V 2
|
||||
/R 3
|
||||
/Length 128
|
||||
/P 4294967292
|
||||
/Filter /Standard
|
||||
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
|
||||
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
|
||||
>>
|
||||
endobj
|
||||
xref
|
||||
0 7
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000059 00000 n
|
||||
0000000118 00000 n
|
||||
0000000167 00000 n
|
||||
0000000348 00000 n
|
||||
0000000440 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 7
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
|
||||
/Encrypt 6 0 R
|
||||
>>
|
||||
startxref
|
||||
655
|
||||
%%EOF
|
||||
@@ -54,6 +54,12 @@ class TestReadPdfFile:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
|
||||
assert text == ""
|
||||
|
||||
def test_owner_password_only_pdf_extracts_text(self) -> None:
|
||||
"""A PDF encrypted with only an owner password (no user password)
|
||||
should still yield its text content. Regression for #9754."""
|
||||
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
|
||||
assert "Hello World" in text
|
||||
|
||||
def test_empty_pdf(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("empty.pdf"))
|
||||
assert text.strip() == ""
|
||||
@@ -117,6 +123,12 @@ class TestIsPdfProtected:
|
||||
def test_protected_pdf(self) -> None:
|
||||
assert is_pdf_protected(_load("encrypted.pdf")) is True
|
||||
|
||||
def test_owner_password_only_is_not_protected(self) -> None:
|
||||
"""A PDF with only an owner password (permission restrictions) but no
|
||||
user password should NOT be considered protected — any viewer can open
|
||||
it without prompting for a password."""
|
||||
assert is_pdf_protected(_load("owner_protected.pdf")) is False
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("simple.pdf")
|
||||
pdf.seek(42)
|
||||
|
||||
79
backend/tests/unit/onyx/file_processing/test_pptx_to_text.py
Normal file
79
backend/tests/unit/onyx/file_processing/test_pptx_to_text.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import io
|
||||
|
||||
from pptx import Presentation # type: ignore[import-untyped]
|
||||
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
|
||||
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
|
||||
from pptx.util import Inches # type: ignore[import-untyped]
|
||||
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
|
||||
|
||||
def _make_pptx_with_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with one text slide and one chart slide."""
|
||||
prs = Presentation()
|
||||
|
||||
# Slide 1: text only
|
||||
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide1.shapes.title.text = "Introduction"
|
||||
slide1.placeholders[1].text = "This is the first slide."
|
||||
|
||||
# Slide 2: chart
|
||||
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
|
||||
chart_data = CategoryChartData()
|
||||
chart_data.categories = ["Q1", "Q2", "Q3"]
|
||||
chart_data.add_series("Revenue", (100, 200, 300))
|
||||
slide2.shapes.add_chart(
|
||||
XL_CHART_TYPE.COLUMN_CLUSTERED,
|
||||
Inches(1),
|
||||
Inches(1),
|
||||
Inches(6),
|
||||
Inches(4),
|
||||
chart_data,
|
||||
)
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
def _make_pptx_without_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with a single text-only slide."""
|
||||
prs = Presentation()
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide.shapes.title.text = "Hello World"
|
||||
slide.placeholders[1].text = "Some content here."
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestPptxToText:
|
||||
def test_chart_is_omitted(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_with_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Introduction" in result
|
||||
assert "first slide" in result
|
||||
assert "[chart omitted]" in result
|
||||
# The actual chart data should NOT appear in the output.
|
||||
assert "Revenue" not in result
|
||||
assert "Q1" not in result
|
||||
|
||||
def test_text_only_pptx(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_without_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Hello World" in result
|
||||
assert "Some content" in result
|
||||
assert "[chart omitted]" not in result
|
||||
@@ -11,6 +11,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
|
||||
from litellm.types.utils import Delta
|
||||
from litellm.types.utils import Function as LiteLLMFunction
|
||||
|
||||
import onyx.llm.models
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
@@ -1479,6 +1480,147 @@ def test_bifrost_normalizes_api_base_in_model_kwargs() -> None:
|
||||
assert llm._model_kwargs["api_base"] == "https://bifrost.example.com/v1"
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_true() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(name="get_weather", arguments="{}"),
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is True
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_no_tools() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="Hello"),
|
||||
AssistantMessage(content="Hi there!"),
|
||||
]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_prompt_contains_tool_call_history_false_user_only() -> None:
|
||||
from onyx.llm.multi_llm import _prompt_contains_tool_call_history
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hello")]
|
||||
assert _prompt_contains_tool_call_history(messages) is False
|
||||
|
||||
|
||||
def test_bedrock_claude_drops_thinking_when_thinking_blocks_missing() -> None:
|
||||
"""When thinking is enabled but assistant messages with tool_calls lack
|
||||
thinking_blocks, the thinking param must be dropped to avoid the Bedrock
|
||||
BadRequestError about missing thinking blocks."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
AssistantMessage(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc_1",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Paris"}',
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
onyx.llm.models.ToolMessage(
|
||||
content="22°C sunny",
|
||||
tool_call_id="tc_1",
|
||||
),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" not in kwargs, (
|
||||
"thinking param should be dropped when thinking_blocks are missing "
|
||||
"from assistant messages with tool_calls"
|
||||
)
|
||||
|
||||
|
||||
def test_bedrock_claude_keeps_thinking_when_no_tool_history() -> None:
|
||||
"""When thinking is enabled and there are no historical assistant messages
|
||||
with tool_calls, the thinking param should be preserved."""
|
||||
llm = LitellmLLM(
|
||||
api_key=None,
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.BEDROCK,
|
||||
model_name="anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
max_input_tokens=200000,
|
||||
)
|
||||
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content="What's the weather?"),
|
||||
]
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
list(llm.stream(messages, tools=tools, reasoning_effort=ReasoningEffort.HIGH))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "thinking" in kwargs, (
|
||||
"thinking param should be preserved when no assistant messages "
|
||||
"with tool_calls exist in history"
|
||||
)
|
||||
assert kwargs["thinking"]["type"] == "enabled"
|
||||
|
||||
|
||||
def test_bifrost_claude_includes_allowed_openai_params() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -13,6 +14,16 @@ from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_broker_client() -> Iterator[None]:
|
||||
"""Patch celery_get_broker_client for all collector tests."""
|
||||
with patch(
|
||||
"onyx.background.celery.celery_redis.celery_get_broker_client",
|
||||
return_value=MagicMock(),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
@@ -24,8 +35,7 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -60,8 +70,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
@@ -74,8 +84,8 @@ class TestQueueDepthCollector:
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
@@ -98,31 +108,10 @@ class TestQueueDepthCollector:
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
MagicMock()
|
||||
collector.set_celery_app(MagicMock())
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
|
||||
@@ -1,96 +1,22 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
"""Tests for indexing pipeline setup."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
class TestCollectorCeleryAppSetup:
|
||||
def test_queue_depth_collector_uses_celery_app(self) -> None:
|
||||
"""QueueDepthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = QueueDepthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
def test_redis_health_collector_uses_celery_app(self) -> None:
|
||||
"""RedisHealthCollector.set_celery_app stores the app for broker access."""
|
||||
collector = RedisHealthCollector()
|
||||
mock_app = MagicMock()
|
||||
collector.set_celery_app(mock_app)
|
||||
assert collector._celery_app is mock_app
|
||||
|
||||
@@ -70,6 +70,10 @@ backend = [
|
||||
"lazy_imports==1.0.1",
|
||||
"lxml==5.3.0",
|
||||
"Mako==1.2.4",
|
||||
# NOTE: Do not update without understanding the patching behavior in
|
||||
# get_markitdown_converter in
|
||||
# backend/onyx/file_processing/extract_file_text.py and what impacts
|
||||
# updating might have on this behavior.
|
||||
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
|
||||
"mcp[cli]==1.26.0",
|
||||
"msal==1.34.0",
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
import useSWR from "swr";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
@@ -24,8 +23,9 @@ import { ProviderIcon } from "@/app/admin/configuration/llm/ProviderIcon";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { Button } from "@opal/components";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
const NO_DEFAULT_VALUE = "__none__";
|
||||
|
||||
@@ -201,10 +201,10 @@ export default function ImageGenerationContent() {
|
||||
<div className="flex flex-col gap-6">
|
||||
{/* Section Header */}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text mainContentEmphasis text05>
|
||||
<Text font="main-content-emphasis" color="text-05">
|
||||
Image Generation Model
|
||||
</Text>
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
Select a model to generate images in chat.
|
||||
</Text>
|
||||
</div>
|
||||
@@ -223,7 +223,7 @@ export default function ImageGenerationContent() {
|
||||
{/* Provider Groups */}
|
||||
{IMAGE_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.name} className="flex flex-col gap-2">
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.name}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
@@ -277,12 +277,13 @@ export default function ImageGenerationContent() {
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model. Session history will be preserved.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
@@ -329,22 +330,24 @@ export default function ImageGenerationContent() {
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> is currently the default
|
||||
image generation model.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** is currently the default image generation model.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using image generation.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectProvider.title}</b> models will no longer be used
|
||||
to generate images.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectProvider.title}** models will no longer be used to generate images.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
|
||||
@@ -1,7 +1 @@
|
||||
"use client";
|
||||
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
}
|
||||
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
|
||||
|
||||
@@ -32,8 +32,10 @@ import {
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
OpenAICompatibleFetchParams,
|
||||
OpenAICompatibleModelResponse,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -44,6 +46,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"openai_compatible",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -82,6 +85,7 @@ export const getProviderIcon = (
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
openai_compatible: SvgPlug,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -411,6 +415,64 @@ export const fetchBifrostModels = async (
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches models from a generic OpenAI-compatible server.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchOpenAICompatibleModels = async (
|
||||
params: OpenAICompatibleFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/admin/llm/openai-compatible/available-models",
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: OpenAICompatibleModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches LiteLLM Proxy models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
@@ -531,6 +593,13 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return fetchOpenAICompatibleModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -545,6 +614,7 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -15,7 +15,7 @@ import { Callout } from "@/components/ui/callout";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SvgGlobe, SvgOnyxLogo, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Button } from "@opal/components";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
@@ -151,7 +151,7 @@ function WebSearchDisconnectModal({
|
||||
description="This will remove the stored credentials for this provider."
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
@@ -159,7 +159,7 @@ function WebSearchDisconnectModal({
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
|
||||
@@ -24,7 +24,6 @@ import {
|
||||
} from "@/app/craft/onboarding/constants";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { buildOnboardingInitialValues as buildInitialValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { testApiKeyHelper } from "@/sections/modals/llmConfig/svc";
|
||||
import OnboardingInfoPages from "@/app/craft/onboarding/components/OnboardingInfoPages";
|
||||
import OnboardingUserInfo from "@/app/craft/onboarding/components/OnboardingUserInfo";
|
||||
@@ -221,10 +220,8 @@ export default function BuildOnboardingModal({
|
||||
setConnectionStatus("testing");
|
||||
setErrorMessage("");
|
||||
|
||||
const baseValues = buildInitialValues();
|
||||
const providerName = `build-mode-${currentProviderConfig.providerName}`;
|
||||
const payload = {
|
||||
...baseValues,
|
||||
name: providerName,
|
||||
provider: currentProviderConfig.providerName,
|
||||
api_key: apiKey,
|
||||
|
||||
@@ -133,7 +133,7 @@ async function createFederatedConnector(
|
||||
|
||||
async function updateFederatedConnector(
|
||||
id: number,
|
||||
credentials: CredentialForm,
|
||||
credentials: CredentialForm | null,
|
||||
config?: ConfigForm
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
try {
|
||||
@@ -143,7 +143,7 @@ async function updateFederatedConnector(
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
credentials,
|
||||
credentials: credentials ?? undefined,
|
||||
config: config || {},
|
||||
}),
|
||||
});
|
||||
@@ -201,7 +201,9 @@ export function FederatedConnectorForm({
|
||||
const isEditMode = connectorId !== undefined;
|
||||
|
||||
const [formState, setFormState] = useState<FormState>({
|
||||
credentials: preloadedConnectorData?.credentials || {},
|
||||
// In edit mode, don't populate credentials with masked values from the API.
|
||||
// Masked values (e.g. "••••••••••••") would be saved back and corrupt the real credentials.
|
||||
credentials: isEditMode ? {} : preloadedConnectorData?.credentials || {},
|
||||
config: preloadedConnectorData?.config || {},
|
||||
schema: preloadedCredentialSchema?.credentials || null,
|
||||
configurationSchema: null,
|
||||
@@ -209,6 +211,7 @@ export function FederatedConnectorForm({
|
||||
configurationSchemaError: null,
|
||||
connectorError: null,
|
||||
});
|
||||
const [credentialsModified, setCredentialsModified] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [submitMessage, setSubmitMessage] = useState<string | null>(null);
|
||||
const [submitSuccess, setSubmitSuccess] = useState<boolean | null>(null);
|
||||
@@ -333,6 +336,7 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
|
||||
const handleCredentialChange = (key: string, value: string) => {
|
||||
setCredentialsModified(true);
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
credentials: {
|
||||
@@ -354,6 +358,11 @@ export function FederatedConnectorForm({
|
||||
|
||||
const handleValidateCredentials = async () => {
|
||||
if (!formState.schema) return;
|
||||
if (isEditMode && !credentialsModified) {
|
||||
setSubmitMessage("Enter new credential values before validating.");
|
||||
setSubmitSuccess(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsValidating(true);
|
||||
setSubmitMessage(null);
|
||||
@@ -411,8 +420,10 @@ export function FederatedConnectorForm({
|
||||
setSubmitSuccess(null);
|
||||
|
||||
try {
|
||||
// Validate required fields
|
||||
if (formState.schema) {
|
||||
const shouldValidateCredentials = !isEditMode || credentialsModified;
|
||||
|
||||
// Validate required fields (skip for credentials in edit mode when unchanged)
|
||||
if (formState.schema && shouldValidateCredentials) {
|
||||
const missingRequired = Object.entries(formState.schema)
|
||||
.filter(
|
||||
([key, field]) => field.required && !formState.credentials[key]
|
||||
@@ -442,16 +453,20 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
setConfigValidationErrors({});
|
||||
|
||||
// Validate credentials before creating/updating
|
||||
const validation = await validateCredentials(
|
||||
connector,
|
||||
formState.credentials
|
||||
);
|
||||
if (!validation.success) {
|
||||
setSubmitMessage(`Credential validation failed: ${validation.message}`);
|
||||
setSubmitSuccess(false);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
// Validate credentials before creating/updating (skip in edit mode when unchanged)
|
||||
if (shouldValidateCredentials) {
|
||||
const validation = await validateCredentials(
|
||||
connector,
|
||||
formState.credentials
|
||||
);
|
||||
if (!validation.success) {
|
||||
setSubmitMessage(
|
||||
`Credential validation failed: ${validation.message}`
|
||||
);
|
||||
setSubmitSuccess(false);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Create or update the connector
|
||||
@@ -459,7 +474,7 @@ export function FederatedConnectorForm({
|
||||
isEditMode && connectorId
|
||||
? await updateFederatedConnector(
|
||||
connectorId,
|
||||
formState.credentials,
|
||||
credentialsModified ? formState.credentials : null,
|
||||
formState.config
|
||||
)
|
||||
: await createFederatedConnector(
|
||||
@@ -538,14 +553,16 @@ export function FederatedConnectorForm({
|
||||
id={fieldKey}
|
||||
type={fieldSpec.secret ? "password" : "text"}
|
||||
placeholder={
|
||||
fieldSpec.example
|
||||
? String(fieldSpec.example)
|
||||
: fieldSpec.description
|
||||
isEditMode && !credentialsModified
|
||||
? "•••••••• (leave blank to keep current value)"
|
||||
: fieldSpec.example
|
||||
? String(fieldSpec.example)
|
||||
: fieldSpec.description
|
||||
}
|
||||
value={formState.credentials[fieldKey] || ""}
|
||||
onChange={(e) => handleCredentialChange(fieldKey, e.target.value)}
|
||||
className="w-96"
|
||||
required={fieldSpec.required}
|
||||
required={!isEditMode && fieldSpec.required}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -4,6 +4,7 @@ import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import {
|
||||
LLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderResponse,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
@@ -136,14 +137,12 @@ export function useAdminLLMProviders() {
|
||||
* Used inside individual provider modals to pre-populate model lists
|
||||
* before the user has entered credentials.
|
||||
*
|
||||
* @param providerEndpoint - The provider's API endpoint name (e.g. "openai", "anthropic").
|
||||
* @param providerName - The provider's API endpoint name (e.g. "openai", "anthropic").
|
||||
* Pass `null` to suppress the request.
|
||||
*/
|
||||
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
|
||||
export function useWellKnownLLMProvider(providerName: LLMProviderName) {
|
||||
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
|
||||
providerEndpoint
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null,
|
||||
providerName ? `/api/admin/llm/built-in/options/${providerName}` : null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
|
||||
@@ -6,6 +6,8 @@ import { INTERNAL_URL, IS_DEV } from "@/lib/constants";
|
||||
const TARGET_SAMPLE_RATE = 24000;
|
||||
const CHUNK_INTERVAL_MS = 250;
|
||||
const DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS = 1500;
|
||||
// When VAD-based auto-stop is disabled, force-stop after this much silence as a fallback
|
||||
const SILENCE_FALLBACK_TIMEOUT_MS = 10000;
|
||||
|
||||
interface TranscriptMessage {
|
||||
type: "transcript" | "error";
|
||||
@@ -58,6 +60,8 @@ class VoiceRecorderSession {
|
||||
private finalTranscriptDelivered = false;
|
||||
private lastDeliveredFinalText: string | null = null;
|
||||
private lastDeliveredFinalAtMs = 0;
|
||||
// Fallback timer: force-stop after extended silence when VAD auto-stop is disabled
|
||||
private silenceFallbackTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
// Callbacks to update React state
|
||||
private onTranscriptChange: (text: string) => void;
|
||||
@@ -174,6 +178,8 @@ class VoiceRecorderSession {
|
||||
async stop(): Promise<string | null> {
|
||||
if (!this.isActive) return this.transcript || null;
|
||||
|
||||
this.resetSilenceFallbackTimer();
|
||||
|
||||
// Stop audio capture
|
||||
if (this.sendInterval) {
|
||||
clearInterval(this.sendInterval);
|
||||
@@ -219,6 +225,7 @@ class VoiceRecorderSession {
|
||||
}
|
||||
|
||||
cleanup(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
if (this.sendInterval) clearInterval(this.sendInterval);
|
||||
if (this.scriptNode) this.scriptNode.disconnect();
|
||||
if (this.sourceNode) this.sourceNode.disconnect();
|
||||
@@ -274,6 +281,23 @@ class VoiceRecorderSession {
|
||||
});
|
||||
}
|
||||
|
||||
private resetSilenceFallbackTimer(): void {
|
||||
if (this.silenceFallbackTimer) {
|
||||
clearTimeout(this.silenceFallbackTimer);
|
||||
this.silenceFallbackTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private startSilenceFallbackTimer(): void {
|
||||
this.resetSilenceFallbackTimer();
|
||||
this.silenceFallbackTimer = setTimeout(() => {
|
||||
// 10s of silence with no new speech — force-stop as a safety fallback
|
||||
if (this.isActive && this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
}, SILENCE_FALLBACK_TIMEOUT_MS);
|
||||
}
|
||||
|
||||
private handleMessage = (event: MessageEvent): void => {
|
||||
try {
|
||||
const data: TranscriptMessage = JSON.parse(event.data);
|
||||
@@ -281,47 +305,53 @@ class VoiceRecorderSession {
|
||||
if (data.type === "transcript") {
|
||||
if (data.text) {
|
||||
this.transcript = data.text;
|
||||
this.onTranscriptChange(data.text);
|
||||
// Only push live updates to React while actively recording.
|
||||
// After stop(), the final transcript is returned via stopResolver
|
||||
// instead — this prevents stale text from reappearing in the
|
||||
// input box when the user clears it and starts a new recording.
|
||||
if (this.isActive) {
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
}
|
||||
|
||||
if (data.is_final && data.text) {
|
||||
// VAD detected silence - trigger callback (only once per utterance)
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.autoStopOnSilence &&
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
// Resolve stop promise if waiting — must run even after stop()
|
||||
// so the caller receives the final transcript.
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
|
||||
// Auto-stop recording if enabled
|
||||
// Skip VAD logic if session is no longer active
|
||||
if (!this.isActive) return;
|
||||
|
||||
if (this.autoStopOnSilence) {
|
||||
// Trigger stop callback to update React state
|
||||
// VAD detected silence — auto-stop and trigger callback
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
if (this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
} else {
|
||||
// If not auto-stopping, reset for next utterance
|
||||
this.transcript = "";
|
||||
this.finalTranscriptDelivered = false;
|
||||
this.onTranscriptChange("");
|
||||
this.resetBackendTranscript();
|
||||
}
|
||||
|
||||
// Resolve stop promise if waiting
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
// Auto-stop disabled (push-to-talk): ignore VAD, keep recording.
|
||||
// Start/reset a 10s fallback timer — if no new speech arrives,
|
||||
// force-stop to avoid recording silence indefinitely.
|
||||
this.startSilenceFallbackTimer();
|
||||
}
|
||||
}
|
||||
} else if (data.type === "error") {
|
||||
|
||||
@@ -14,6 +14,7 @@ export enum LLMProviderName {
|
||||
BEDROCK = "bedrock",
|
||||
LITELLM_PROXY = "litellm_proxy",
|
||||
BIFROST = "bifrost",
|
||||
OPENAI_COMPATIBLE = "openai_compatible",
|
||||
CUSTOM = "custom",
|
||||
}
|
||||
|
||||
@@ -123,14 +124,11 @@ export interface LLMProviderFormProps {
|
||||
shouldMarkAsDefault?: boolean;
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
|
||||
/** The current default model name for this provider (from the global default). */
|
||||
defaultModelName?: string;
|
||||
/** Called after successful provider creation/update. */
|
||||
onSuccess?: () => void | Promise<void>;
|
||||
|
||||
// Onboarding-specific (only when variant === "onboarding")
|
||||
onboardingState?: OnboardingState;
|
||||
onboardingActions?: OnboardingActions;
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
}
|
||||
|
||||
// Param types for model fetching functions - use snake_case to match API structure
|
||||
@@ -181,6 +179,21 @@ export interface BifrostModelResponse {
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface OpenAICompatibleFetchParams {
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OpenAICompatibleModelResponse {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface VertexAIFetchParams {
|
||||
model_configurations?: ModelConfiguration[];
|
||||
}
|
||||
@@ -199,5 +212,6 @@ export type FetchModelsParams =
|
||||
| OpenRouterFetchParams
|
||||
| LiteLLMProxyFetchParams
|
||||
| BifrostFetchParams
|
||||
| OpenAICompatibleFetchParams
|
||||
| VertexAIFetchParams
|
||||
| LMStudioFetchParams;
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import type { RichStr } from "@opal/types";
|
||||
import type { RichStr, WithoutStyles } from "@opal/types";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
|
||||
import { useField, useFormikContext } from "formik";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
@@ -234,9 +235,27 @@ function ErrorTextLayout({ children, type = "error" }: ErrorTextLayoutProps) {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldSeparator - A horizontal rule with inline padding, used to visually separate field groups.
|
||||
*/
|
||||
function FieldSeparator() {
|
||||
return <Separator noPadding className="p-2" />;
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldPadder - Wraps a field in standard horizontal + vertical padding (`p-2 w-full`).
|
||||
*/
|
||||
type FieldPadderProps = WithoutStyles<React.HTMLAttributes<HTMLDivElement>>;
|
||||
function FieldPadder(props: FieldPadderProps) {
|
||||
return <div {...props} className="p-2 w-full" />;
|
||||
}
|
||||
|
||||
export {
|
||||
VerticalInputLayout as Vertical,
|
||||
HorizontalInputLayout as Horizontal,
|
||||
ErrorLayout as Error,
|
||||
ErrorTextLayout,
|
||||
FieldSeparator,
|
||||
FieldPadder,
|
||||
type FieldPadderProps,
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
SvgCloud,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgPlug,
|
||||
SvgServer,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
@@ -28,6 +29,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
@@ -45,6 +47,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
@@ -62,6 +65,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",
|
||||
|
||||
@@ -142,6 +142,7 @@ function PopoverContent({
|
||||
collisionPadding={8}
|
||||
className={cn(
|
||||
"bg-background-neutral-00 p-1 z-popover rounded-12 border shadow-md data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
"flex flex-col",
|
||||
"max-h-[var(--radix-popover-content-available-height)]",
|
||||
"overflow-hidden",
|
||||
widthClasses[width]
|
||||
@@ -226,7 +227,7 @@ export function PopoverMenu({
|
||||
});
|
||||
|
||||
return (
|
||||
<Section alignItems="stretch">
|
||||
<Section alignItems="stretch" height="auto" className="flex-1 min-h-0">
|
||||
<ShadowDiv
|
||||
scrollContainerRef={scrollContainerRef}
|
||||
className="flex flex-col gap-1 max-h-[20rem] w-full"
|
||||
|
||||
@@ -105,7 +105,7 @@ export default function ShadowDiv({
|
||||
}, [containerRef, checkScroll]);
|
||||
|
||||
return (
|
||||
<div className="relative min-h-0">
|
||||
<div className="relative min-h-0 flex flex-col">
|
||||
<div
|
||||
ref={containerRef}
|
||||
className={cn("overflow-y-auto", className)}
|
||||
|
||||
@@ -984,8 +984,8 @@ function ChatPreferencesSettings() {
|
||||
/>
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Send"
|
||||
description="Automatically send voice input when recording stops."
|
||||
title="Auto-Send on Pause"
|
||||
description="Automatically send voice input when you stop speaking."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_send ?? false}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR from "swr";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers } from "@opal/icons";
|
||||
@@ -14,16 +13,14 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
import useGroupMemberCandidates from "./useGroupMemberCandidates";
|
||||
import {
|
||||
createGroup,
|
||||
updateAgentGroupSharing,
|
||||
updateDocSetGroupSharing,
|
||||
saveTokenLimits,
|
||||
} from "./svc";
|
||||
import { apiKeyToMemberRow, memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import { memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
|
||||
import TokenLimitSection from "./TokenLimitSection";
|
||||
import type { TokenLimit } from "./TokenLimitSection";
|
||||
@@ -41,22 +38,7 @@ function CreateGroupPage() {
|
||||
{ tokenBudget: null, periodHours: null },
|
||||
]);
|
||||
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const isLoading = usersLoading || apiKeysLoading;
|
||||
const error = usersError ?? apiKeysError;
|
||||
|
||||
const allRows: MemberRow[] = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
const { rows: allRows, isLoading, error } = useGroupMemberCandidates();
|
||||
|
||||
async function handleCreate() {
|
||||
const trimmed = groupName.trim();
|
||||
@@ -133,11 +115,11 @@ function CreateGroupPage() {
|
||||
{/* Members table */}
|
||||
{isLoading && <SimpleLoader />}
|
||||
|
||||
{error && (
|
||||
{error ? (
|
||||
<Text as="p" secondaryBody text03>
|
||||
Failed to load users.
|
||||
</Text>
|
||||
)}
|
||||
) : null}
|
||||
|
||||
{!isLoading && !error && (
|
||||
<Section
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR, { useSWRConfig } from "swr";
|
||||
import useGroupMemberCandidates from "./useGroupMemberCandidates";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers, SvgTrash, SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
@@ -19,20 +20,9 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import type {
|
||||
ApiKeyDescriptor,
|
||||
MemberRow,
|
||||
TokenRateLimitDisplay,
|
||||
} from "./interfaces";
|
||||
import {
|
||||
apiKeyToMemberRow,
|
||||
baseColumns,
|
||||
memberTableColumns,
|
||||
tc,
|
||||
PAGE_SIZE,
|
||||
} from "./shared";
|
||||
import type { MemberRow, TokenRateLimitDisplay } from "./interfaces";
|
||||
import { baseColumns, memberTableColumns, tc, PAGE_SIZE } from "./shared";
|
||||
import {
|
||||
USER_GROUP_URL,
|
||||
renameGroup,
|
||||
@@ -104,18 +94,15 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
const initialAgentIdsRef = useRef<number[]>([]);
|
||||
const initialDocSetIdsRef = useRef<number[]>([]);
|
||||
|
||||
// Users and API keys
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
// Users + service accounts (curator-accessible — see hook docs).
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
rows: allRows,
|
||||
isLoading: candidatesLoading,
|
||||
error: candidatesError,
|
||||
} = useGroupMemberCandidates();
|
||||
|
||||
const isLoading =
|
||||
groupLoading || usersLoading || apiKeysLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? usersError ?? apiKeysError;
|
||||
const isLoading = groupLoading || candidatesLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? candidatesError;
|
||||
|
||||
// Pre-populate form when group data loads
|
||||
useEffect(() => {
|
||||
@@ -145,12 +132,6 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
}
|
||||
}, [tokenRateLimits]);
|
||||
|
||||
const allRows = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
|
||||
const memberRows = useMemo(() => {
|
||||
const selected = new Set(selectedUserIds);
|
||||
return allRows.filter((r) => selected.has(r.id ?? r.email));
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
|
||||
// Curator-accessible listing of all users (and service-account entries via
|
||||
// `?include_api_keys=true`). The admin-only `/manage/users/accepted/all` and
|
||||
// `/manage/users/invited` endpoints 403 for global curators, which used to
|
||||
// break the Edit Group page entirely — see useGroupMemberCandidates docs.
|
||||
const GROUP_MEMBER_CANDIDATES_URL = "/api/manage/users?include_api_keys=true";
|
||||
const ADMIN_API_KEYS_URL = "/api/admin/api-key";
|
||||
import { UserStatus, type UserRole } from "@/lib/types";
|
||||
|
||||
// Mirrors `DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN` on the backend; service-account
|
||||
// users are identified by this email suffix because release/v3.1 does not yet
|
||||
// expose `account_type` on the FullUserSnapshot returned from `/manage/users`.
|
||||
const API_KEY_EMAIL_SUFFIX = "@onyxapikey.ai";
|
||||
|
||||
function isApiKeyEmail(email: string): boolean {
|
||||
return email.endsWith(API_KEY_EMAIL_SUFFIX);
|
||||
}
|
||||
import type {
|
||||
UserGroupInfo,
|
||||
UserRow,
|
||||
} from "@/refresh-pages/admin/UsersPage/interfaces";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
|
||||
// Backend response shape for `/api/manage/users?include_api_keys=true`. The
|
||||
// existing `AllUsersResponse` in `lib/types.ts` types `accepted` as `User[]`,
|
||||
// which is missing fields the table needs (`personal_name`, `account_type`,
|
||||
// `groups`, etc.), so we declare an accurate local type here.
|
||||
interface FullUserSnapshot {
|
||||
id: string;
|
||||
email: string;
|
||||
role: UserRole;
|
||||
is_active: boolean;
|
||||
password_configured: boolean;
|
||||
personal_name: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
groups: UserGroupInfo[];
|
||||
is_scim_synced: boolean;
|
||||
}
|
||||
|
||||
interface ManageUsersResponse {
|
||||
accepted: FullUserSnapshot[];
|
||||
invited: { email: string }[];
|
||||
slack_users: FullUserSnapshot[];
|
||||
accepted_pages: number;
|
||||
invited_pages: number;
|
||||
slack_users_pages: number;
|
||||
}
|
||||
|
||||
function snapshotToMemberRow(snapshot: FullUserSnapshot): MemberRow {
|
||||
return {
|
||||
id: snapshot.id,
|
||||
email: snapshot.email,
|
||||
role: snapshot.role,
|
||||
status: snapshot.is_active ? UserStatus.ACTIVE : UserStatus.INACTIVE,
|
||||
is_active: snapshot.is_active,
|
||||
is_scim_synced: snapshot.is_scim_synced,
|
||||
personal_name: snapshot.personal_name,
|
||||
created_at: snapshot.created_at,
|
||||
updated_at: snapshot.updated_at,
|
||||
groups: snapshot.groups,
|
||||
};
|
||||
}
|
||||
|
||||
function serviceAccountToMemberRow(
|
||||
snapshot: FullUserSnapshot,
|
||||
apiKey: ApiKeyDescriptor | undefined
|
||||
): MemberRow {
|
||||
return {
|
||||
id: snapshot.id,
|
||||
email: "Service Account",
|
||||
role: apiKey?.api_key_role ?? snapshot.role,
|
||||
status: UserStatus.ACTIVE,
|
||||
is_active: true,
|
||||
is_scim_synced: false,
|
||||
personal_name:
|
||||
apiKey?.api_key_name ?? snapshot.personal_name ?? "Unnamed Key",
|
||||
created_at: null,
|
||||
updated_at: null,
|
||||
groups: [],
|
||||
api_key_display: apiKey?.api_key_display,
|
||||
};
|
||||
}
|
||||
|
||||
interface UseGroupMemberCandidatesResult {
|
||||
/** Active users + service-account rows, in the order the table expects. */
|
||||
rows: MemberRow[];
|
||||
/** Subset of `rows` representing real (non-service-account) users. */
|
||||
userRows: MemberRow[];
|
||||
isLoading: boolean;
|
||||
error: unknown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the candidate list for the group create/edit member pickers.
|
||||
*
|
||||
* Hits `/api/manage/users?include_api_keys=true`, which is gated by
|
||||
* `current_curator_or_admin_user` on the backend, so this works for both
|
||||
* admins and global curators (the admin-only `/accepted/all` and `/invited`
|
||||
* endpoints used to be called here, which 403'd for global curators and broke
|
||||
* the Edit Group page entirely).
|
||||
*
|
||||
* For admins, we additionally fetch `/admin/api-key` to enrich service-account
|
||||
* rows with the masked api-key display string. That call is admin-only and is
|
||||
* skipped for curators; its failure is non-fatal.
|
||||
*/
|
||||
export default function useGroupMemberCandidates(): UseGroupMemberCandidatesResult {
|
||||
const { isAdmin } = useUser();
|
||||
|
||||
const {
|
||||
data: usersData,
|
||||
isLoading: usersLoading,
|
||||
error: usersError,
|
||||
} = useSWR<ManageUsersResponse>(
|
||||
GROUP_MEMBER_CANDIDATES_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const { data: apiKeys, isLoading: apiKeysLoading } = useSWR<
|
||||
ApiKeyDescriptor[]
|
||||
>(isAdmin ? ADMIN_API_KEYS_URL : null, errorHandlingFetcher);
|
||||
|
||||
const apiKeysByUserId = useMemo(() => {
|
||||
const map = new Map<string, ApiKeyDescriptor>();
|
||||
for (const key of apiKeys ?? []) map.set(key.user_id, key);
|
||||
return map;
|
||||
}, [apiKeys]);
|
||||
|
||||
const { rows, userRows } = useMemo(() => {
|
||||
const accepted = usersData?.accepted ?? [];
|
||||
const userRowsLocal: MemberRow[] = [];
|
||||
const serviceAccountRows: MemberRow[] = [];
|
||||
for (const snapshot of accepted) {
|
||||
if (!snapshot.is_active) continue;
|
||||
if (isApiKeyEmail(snapshot.email)) {
|
||||
serviceAccountRows.push(
|
||||
serviceAccountToMemberRow(snapshot, apiKeysByUserId.get(snapshot.id))
|
||||
);
|
||||
} else {
|
||||
userRowsLocal.push(snapshotToMemberRow(snapshot));
|
||||
}
|
||||
}
|
||||
return {
|
||||
rows: [...userRowsLocal, ...serviceAccountRows],
|
||||
userRows: userRowsLocal,
|
||||
};
|
||||
}, [usersData, apiKeysByUserId]);
|
||||
|
||||
return {
|
||||
rows,
|
||||
userRows,
|
||||
isLoading: usersLoading || (isAdmin && apiKeysLoading),
|
||||
error: usersError,
|
||||
};
|
||||
}
|
||||
@@ -31,6 +31,7 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
@@ -43,9 +44,10 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
@@ -57,16 +59,18 @@ const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
// Client-side ordering for the "Add Provider" cards. The backend may return
|
||||
// wellKnownLLMProviders in an arbitrary order, so we sort explicitly here.
|
||||
const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"vertex_ai",
|
||||
"bedrock",
|
||||
"azure",
|
||||
"litellm_proxy",
|
||||
"ollama_chat",
|
||||
"openrouter",
|
||||
"lm_studio",
|
||||
"bifrost",
|
||||
LLMProviderName.OPENAI,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
LLMProviderName.BEDROCK,
|
||||
LLMProviderName.AZURE,
|
||||
"litellm",
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
@@ -127,7 +131,7 @@ const PROVIDER_MODAL_MAP: Record<
|
||||
/>
|
||||
),
|
||||
lm_studio: (d, open, onOpenChange) => (
|
||||
<LMStudioForm
|
||||
<LMStudioModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
@@ -147,6 +151,13 @@ const PROVIDER_MODAL_MAP: Record<
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
openai_compatible: (d, open, onOpenChange) => (
|
||||
<OpenAICompatibleModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
@@ -341,7 +352,7 @@ function NewCustomProviderCard({
|
||||
// LLMConfigurationPage — main page component
|
||||
// ============================================================================
|
||||
|
||||
export default function LLMConfigurationPage() {
|
||||
export default function LLMProviderConfigurationPage() {
|
||||
const { mutate } = useSWRConfig();
|
||||
const { llmProviders: existingLlmProviders, defaultText } =
|
||||
useAdminLLMProviders();
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
IconProps,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Select } from "@/refresh-components/cards";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
@@ -26,7 +25,8 @@ import { toast } from "@/hooks/useToast";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgMicrophone, SvgSlash, SvgUnplug } from "@opal/icons";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Button, Text } from "@opal/components";
|
||||
import { markdown } from "@opal/utils";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
@@ -205,7 +205,7 @@ function VoiceDisconnectModal({
|
||||
description="Voice models"
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<OpalButton
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={onDisconnect}
|
||||
disabled={
|
||||
@@ -213,19 +213,19 @@ function VoiceDisconnectModal({
|
||||
}
|
||||
>
|
||||
Disconnect
|
||||
</OpalButton>
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
{needsReplacement ? (
|
||||
hasReplacements ? (
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
used for speech-to-text or text-to-speech, and it will no longer
|
||||
be your default. Session history will be preserved.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be used for speech-to-text or text-to-speech, and it will no longer be your default. Session history will be preserved.`
|
||||
)}
|
||||
</Text>
|
||||
<Section alignItems="start" gap={0.25}>
|
||||
<Text as="p" text04>
|
||||
<Text as="p" color="text-04">
|
||||
Set New Default
|
||||
</Text>
|
||||
<InputSelect
|
||||
@@ -256,23 +256,24 @@ function VoiceDisconnectModal({
|
||||
</Section>
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
used for speech-to-text or text-to-speech, and it will no longer
|
||||
be your default.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be used for speech-to-text or text-to-speech, and it will no longer be your default.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Connect another provider to continue using voice.
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
<Text as="p" text03>
|
||||
<b>{disconnectTarget.providerLabel}</b> models will no longer be
|
||||
available for voice.
|
||||
<Text as="p" color="text-03">
|
||||
{markdown(
|
||||
`**${disconnectTarget.providerLabel}** models will no longer be available for voice.`
|
||||
)}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
<Text as="p" color="text-03">
|
||||
Session history will be preserved.
|
||||
</Text>
|
||||
</>
|
||||
@@ -536,7 +537,7 @@ export default function VoiceConfigurationPage() {
|
||||
<Callout type="danger" title="Failed to load voice settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" mainContentBody text03>
|
||||
<Text as="p" font="main-content-body" color="text-03">
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
@@ -626,7 +627,7 @@ export default function VoiceConfigurationPage() {
|
||||
|
||||
{TTS_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.providerType} className="flex flex-col gap-2">
|
||||
<Text secondaryBody text03>
|
||||
<Text font="secondary-body" color="text-03">
|
||||
{group.providerLabel}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
|
||||
@@ -122,7 +122,10 @@ function MicrophoneButton({
|
||||
startRecording,
|
||||
stopRecording,
|
||||
setMuted,
|
||||
} = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });
|
||||
} = useVoiceRecorder({
|
||||
onFinalTranscript: handleFinalTranscript,
|
||||
autoStopOnSilence: autoSend,
|
||||
});
|
||||
|
||||
// Expose stopRecording to parent
|
||||
useEffect(() => {
|
||||
|
||||
@@ -1,33 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
|
||||
const ANTHROPIC_PROVIDER_NAME = "anthropic";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export default function AnthropicModal({
|
||||
variant = "llm-configuration",
|
||||
@@ -35,143 +25,78 @@ export default function AnthropicModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
existingLlmProvider
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues = isOnboarding
|
||||
? {
|
||||
...buildOnboardingInitialValues(),
|
||||
name: ANTHROPIC_PROVIDER_NAME,
|
||||
provider: ANTHROPIC_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.ANTHROPIC}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: ANTHROPIC_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: ANTHROPIC_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<APIKeyField providerName="Anthropic" />
|
||||
<APIKeyField providerName="Anthropic" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. claude-sonnet-4-5" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,41 +1,35 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { useFormikContext } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
ModelsField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModelSelectionField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import {
|
||||
isValidAzureTargetUri,
|
||||
parseAzureTargetUri,
|
||||
} from "@/lib/azureTargetUri";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const AZURE_PROVIDER_NAME = "azure";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface AzureModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
@@ -45,6 +39,33 @@ interface AzureModalValues extends BaseLLMFormValues {
|
||||
deployment_name?: string;
|
||||
}
|
||||
|
||||
function AzureModelSelection() {
|
||||
const formikProps = useFormikContext<AzureModalValues>();
|
||||
return (
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onAddModel={(modelName) => {
|
||||
const current = formikProps.values.model_configurations;
|
||||
if (current.some((m) => m.name === modelName)) return;
|
||||
const updated = [
|
||||
...current,
|
||||
{
|
||||
name: modelName,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
];
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
if (!formikProps.values.test_model_name) {
|
||||
formikProps.setFieldValue("test_model_name", modelName);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function buildTargetUri(existingLlmProvider?: LLMProviderView): string {
|
||||
if (!existingLlmProvider?.api_base || !existingLlmProvider?.api_version) {
|
||||
return "";
|
||||
@@ -81,160 +102,105 @@ export default function AzureModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(AZURE_PROVIDER_NAME);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: AzureModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.AZURE,
|
||||
existingLlmProvider
|
||||
),
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
} as AzureModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
extra: {
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
},
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: AzureModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: AZURE_PROVIDER_NAME,
|
||||
provider: AZURE_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
target_uri: "",
|
||||
default_model_name: "",
|
||||
} as AzureModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.AZURE}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const processedValues = processValues(values);
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: AZURE_PROVIDER_NAME,
|
||||
payload: {
|
||||
...processedValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: AZURE_PROVIDER_NAME,
|
||||
values: processedValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.AZURE,
|
||||
values: processedValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={AZURE_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="target_uri"
|
||||
title="Target URI"
|
||||
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="target_uri"
|
||||
title="Target URI"
|
||||
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="target_uri"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<InputTypeInField
|
||||
name="target_uri"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<APIKeyField providerName="Azure" />
|
||||
<APIKeyField providerName="Azure" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<AzureModelSelection />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import { useFormikContext } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputSelectField from "@/refresh-components/form/InputSelectField";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
@@ -10,30 +10,22 @@ import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
|
||||
import { Card } from "@opal/components";
|
||||
@@ -41,9 +33,9 @@ import { Section } from "@/layouts/general-layouts";
|
||||
import { SvgAlertCircle } from "@opal/icons";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
|
||||
const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const AWS_REGION_OPTIONS = [
|
||||
{ name: "us-east-1", value: "us-east-1" },
|
||||
{ name: "us-east-2", value: "us-east-2" },
|
||||
@@ -79,26 +71,15 @@ interface BedrockModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface BedrockModalInternalsProps {
|
||||
formikProps: FormikProps<BedrockModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BedrockModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BedrockModalInternalsProps) {
|
||||
const formikProps = useFormikContext<BedrockModalValues>();
|
||||
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
|
||||
|
||||
useEffect(() => {
|
||||
@@ -115,11 +96,6 @@ function BedrockModalInternals({
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [authMethod]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isAuthComplete =
|
||||
authMethod === AUTH_METHOD_IAM ||
|
||||
(authMethod === AUTH_METHOD_ACCESS_KEY &&
|
||||
@@ -139,12 +115,12 @@ function BedrockModalInternals({
|
||||
formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
|
||||
aws_bearer_token_bedrock:
|
||||
formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.BEDROCK,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -159,16 +135,8 @@ function BedrockModalInternals({
|
||||
});
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={BEDROCK_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Section gap={1}>
|
||||
<InputLayouts.Vertical
|
||||
name={FIELD_AWS_REGION_NAME}
|
||||
@@ -222,7 +190,7 @@ function BedrockModalInternals({
|
||||
</InputSelect>
|
||||
</InputLayouts.Vertical>
|
||||
</Section>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
{authMethod === AUTH_METHOD_ACCESS_KEY && (
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
@@ -250,7 +218,7 @@ function BedrockModalInternals({
|
||||
)}
|
||||
|
||||
{authMethod === AUTH_METHOD_IAM && (
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Card backgroundVariant="none" borderVariant="solid">
|
||||
<Content
|
||||
icon={SvgAlertCircle}
|
||||
@@ -259,7 +227,7 @@ function BedrockModalInternals({
|
||||
sizePreset="main-ui"
|
||||
/>
|
||||
</Card>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
)}
|
||||
|
||||
{authMethod === AUTH_METHOD_LONG_TERM_API_KEY && (
|
||||
@@ -280,32 +248,24 @@ function BedrockModalInternals({
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. us.anthropic.claude-sonnet-4-5-v1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -315,86 +275,54 @@ export default function BedrockModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: BedrockModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.BEDROCK,
|
||||
existingLlmProvider
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ?? "",
|
||||
BEDROCK_AUTH_METHOD:
|
||||
(existingLlmProvider?.custom_config?.BEDROCK_AUTH_METHOD as string) ??
|
||||
"access_key",
|
||||
AWS_ACCESS_KEY_ID:
|
||||
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ?? "",
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
(existingLlmProvider?.custom_config?.AWS_SECRET_ACCESS_KEY as string) ??
|
||||
"",
|
||||
AWS_BEARER_TOKEN_BEDROCK:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
|
||||
},
|
||||
} as BedrockModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
extra: {
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
},
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: BedrockModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BEDROCK_PROVIDER_NAME,
|
||||
provider: BEDROCK_PROVIDER_NAME,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
AWS_REGION_NAME: "",
|
||||
BEDROCK_AUTH_METHOD: "access_key",
|
||||
AWS_ACCESS_KEY_ID: "",
|
||||
AWS_SECRET_ACCESS_KEY: "",
|
||||
AWS_BEARER_TOKEN_BEDROCK: "",
|
||||
},
|
||||
} as BedrockModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??
|
||||
"",
|
||||
BEDROCK_AUTH_METHOD:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.BEDROCK_AUTH_METHOD as string) ?? "access_key",
|
||||
AWS_ACCESS_KEY_ID:
|
||||
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ??
|
||||
"",
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_SECRET_ACCESS_KEY as string) ?? "",
|
||||
AWS_BEARER_TOKEN_BEDROCK:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.BEDROCK}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
@@ -407,51 +335,37 @@ export default function BedrockModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BEDROCK_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BEDROCK_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.BEDROCK,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BedrockModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<BedrockModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,45 +1,33 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
APIBaseField,
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
|
||||
const DEFAULT_API_BASE = "";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface BifrostModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
@@ -47,30 +35,15 @@ interface BifrostModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface BifrostModalInternalsProps {
|
||||
formikProps: FormikProps<BifrostModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BifrostModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BifrostModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
const formikProps = useFormikContext<BifrostModalValues>();
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
@@ -78,12 +51,12 @@ function BifrostModalInternals({
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.BIFROST,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -100,69 +73,39 @@ function BifrostModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.BIFROST}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
title="API Key"
|
||||
optional={true}
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -172,107 +115,64 @@ export default function BifrostModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BIFROST_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: BifrostModalValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.BIFROST,
|
||||
existingLlmProvider
|
||||
) as BifrostModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: BifrostModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BIFROST_PROVIDER_NAME,
|
||||
provider: BIFROST_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as BifrostModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.BIFROST}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.BIFROST,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BifrostModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<BifrostModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -70,7 +70,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
}
|
||||
) {
|
||||
const nameInput = screen.getByPlaceholderText("Display Name");
|
||||
const providerInput = screen.getByPlaceholderText("Provider Name");
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
"Provider Name as shown on LiteLLM"
|
||||
);
|
||||
|
||||
await user.type(nameInput, options.name);
|
||||
await user.type(providerInput, options.provider);
|
||||
@@ -498,7 +500,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
const nameInput = screen.getByPlaceholderText("Display Name");
|
||||
await user.type(nameInput, "Cloudflare Provider");
|
||||
|
||||
const providerInput = screen.getByPlaceholderText("Provider Name");
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
"Provider Name as shown on LiteLLM"
|
||||
);
|
||||
await user.type(providerInput, "cloudflare");
|
||||
|
||||
// Click "Add Line" button for custom config (aria-label from KeyValueInput)
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import { LLMProviderFormProps, ModelConfiguration } from "@/interfaces/llm";
|
||||
import { useFormikContext } from "formik";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useInitialValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildOnboardingInitialValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
ModelsAccessField,
|
||||
LLMConfigurationModalWrapper,
|
||||
FieldWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
@@ -32,6 +31,7 @@ import { Button, Card, EmptyMessageCard } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
@@ -109,13 +109,10 @@ function ModelConfigurationItem({
|
||||
);
|
||||
}
|
||||
|
||||
interface ModelConfigurationListProps {
|
||||
formikProps: FormikProps<{
|
||||
function ModelConfigurationList() {
|
||||
const formikProps = useFormikContext<{
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
}>;
|
||||
}
|
||||
|
||||
function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
|
||||
}>();
|
||||
const models = formikProps.values.model_configurations;
|
||||
|
||||
function handleChange(index: number, next: CustomModelConfiguration) {
|
||||
@@ -181,6 +178,19 @@ function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
|
||||
);
|
||||
}
|
||||
|
||||
function CustomConfigKeyValue() {
|
||||
const formikProps = useFormikContext<{ custom_config_list: KeyValue[] }>();
|
||||
return (
|
||||
<KeyValueInput
|
||||
items={formikProps.values.custom_config_list}
|
||||
onChange={(items) =>
|
||||
formikProps.setFieldValue("custom_config_list", items)
|
||||
}
|
||||
addButtonLabel="Add Line"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Custom Config Processing ─────────────────────────────────────────────────
|
||||
|
||||
function customConfigProcessing(items: KeyValue[]) {
|
||||
@@ -197,39 +207,36 @@ export default function CustomModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
undefined,
|
||||
defaultModelName
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.CUSTOM,
|
||||
existingLlmProvider
|
||||
),
|
||||
...(isOnboarding ? buildOnboardingInitialValues() : {}),
|
||||
provider: existingLlmProvider?.provider ?? "",
|
||||
api_version: existingLlmProvider?.api_version ?? "",
|
||||
model_configurations: existingLlmProvider?.model_configurations.map(
|
||||
(mc) => ({
|
||||
name: mc.name,
|
||||
display_name: mc.display_name ?? "",
|
||||
is_visible: mc.is_visible,
|
||||
max_input_tokens: mc.max_input_tokens ?? null,
|
||||
supports_image_input: mc.supports_image_input,
|
||||
supports_reasoning: mc.supports_reasoning,
|
||||
})
|
||||
) ?? [
|
||||
{
|
||||
name: "",
|
||||
display_name: "",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
@@ -260,12 +267,18 @@ export default function CustomModal({
|
||||
model_configurations: Yup.array(modelConfigurationSchema),
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.CUSTOM}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
setSubmitting(true);
|
||||
|
||||
const modelConfigurations = values.model_configurations
|
||||
@@ -285,127 +298,123 @@ export default function CustomModal({
|
||||
return;
|
||||
}
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
await submitOnboardingProvider({
|
||||
providerName: values.provider,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigurations,
|
||||
custom_config: customConfigProcessing(values.custom_config_list),
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: true,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
// Always send custom_config as a dict (even empty) so the backend
|
||||
// preserves it as non-null — this is the signal that the provider was
|
||||
// created via CustomModal.
|
||||
const customConfig = customConfigProcessing(values.custom_config_list);
|
||||
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(values.custom_config_list),
|
||||
},
|
||||
initialValues: {
|
||||
...initialValues,
|
||||
custom_config: customConfigProcessing(
|
||||
initialValues.custom_config_list
|
||||
),
|
||||
},
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: (values as Record<string, unknown>).provider as string,
|
||||
values: {
|
||||
...values,
|
||||
model_configurations: modelConfigurations,
|
||||
custom_config: customConfig,
|
||||
},
|
||||
initialValues: {
|
||||
...initialValues,
|
||||
custom_config: customConfigProcessing(
|
||||
initialValues.custom_config_list
|
||||
),
|
||||
},
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
isCustomProvider: true,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint="custom"
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="provider"
|
||||
title="Provider Name"
|
||||
subDescription={markdown(
|
||||
"Should be one of the providers listed at [LiteLLM](https://docs.litellm.ai/docs/providers)."
|
||||
)}
|
||||
>
|
||||
{!isOnboarding && (
|
||||
<Section gap={0}>
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
<InputTypeInField
|
||||
name="provider"
|
||||
placeholder="Provider Name as shown on LiteLLM"
|
||||
variant={existingLlmProvider ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="provider"
|
||||
title="Provider Name"
|
||||
subDescription="Should be one of the providers listed at https://docs.litellm.ai/docs/providers."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="provider"
|
||||
placeholder="Provider Name"
|
||||
variant={existingLlmProvider ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</Section>
|
||||
)}
|
||||
<APIBaseField optional />
|
||||
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical name="api_version" title="API Version" optional>
|
||||
<InputTypeInField name="api_version" />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldWrapper>
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Provider Configs"
|
||||
description="Add properties as needed by the model provider. This is passed to LiteLLM completion() call as arguments in the environment variable. See LiteLLM documentation for more instructions."
|
||||
widthVariant="full"
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
/>
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription="Paste your API key if your model provider requires authentication."
|
||||
/>
|
||||
|
||||
<KeyValueInput
|
||||
items={formikProps.values.custom_config_list}
|
||||
onChange={(items) =>
|
||||
formikProps.setFieldValue("custom_config_list", items)
|
||||
}
|
||||
addButtonLabel="Add Line"
|
||||
/>
|
||||
</Section>
|
||||
</FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Additional Configs"
|
||||
description={markdown(
|
||||
"Add extra properties as needed by the model provider. These are passed to LiteLLM's `completion()` call as [environment variables](https://docs.litellm.ai/docs/set_keys#environment-variables). See [documentation](https://docs.onyx.app/admins/ai_models/custom_inference_provider) for more instructions."
|
||||
)}
|
||||
widthVariant="full"
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
/>
|
||||
|
||||
<FieldSeparator />
|
||||
<CustomConfigKeyValue />
|
||||
</Section>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Section gap={0.5}>
|
||||
<FieldWrapper>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</FieldWrapper>
|
||||
|
||||
<Card>
|
||||
<ModelConfigurationList formikProps={formikProps as any} />
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<Section gap={0.5}>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Card sizeVariant="lg">
|
||||
<ModelConfigurationList />
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
interface LMStudioFormValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface LMStudioFormInternalsProps {
|
||||
formikProps: FormikProps<LMStudioFormValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LMStudioFormInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: LMStudioFormInternalsProps) {
|
||||
const initialApiKey =
|
||||
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ?? "";
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
|
||||
fetchModels(
|
||||
LLMProviderName.LM_STUDIO,
|
||||
{
|
||||
api_base: apiBase,
|
||||
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
|
||||
api_key_changed: apiKey !== initialApiKey,
|
||||
name: existingLlmProvider?.name,
|
||||
},
|
||||
signal
|
||||
).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setFetchedModels(data.models);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, initialApiKey, setFetchedModels]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
const apiBase = formikProps.values.api_base;
|
||||
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
useEffect(() => {
|
||||
if (apiBase) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(apiBase, apiKey, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [apiBase, apiKey, debouncedFetchModels, setFetchedModels]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.LM_STUDIO}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="The base URL for your LM Studio server."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="Your LM Studio API base URL"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
title="API Key"
|
||||
subDescription="Optional API key if your LM Studio server requires authentication."
|
||||
optional
|
||||
>
|
||||
<PasswordInputTypeInField
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
placeholder="API Key"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. llama3.1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LMStudioForm({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
LLMProviderName.LM_STUDIO
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: LMStudioFormValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: LLMProviderName.LM_STUDIO,
|
||||
provider: LLMProviderName.LM_STUDIO,
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY: "",
|
||||
},
|
||||
} as LMStudioFormValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY:
|
||||
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ??
|
||||
"",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LMStudioFormInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
217
web/src/sections/modals/llmConfig/LMStudioModal.tsx
Normal file
217
web/src/sections/modals/llmConfig/LMStudioModal.tsx
Normal file
@@ -0,0 +1,217 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues as BaseLLMModalValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
interface LMStudioModalValues extends BaseLLMModalValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface LMStudioModalInternalsProps {
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LMStudioModalInternals({
|
||||
existingLlmProvider,
|
||||
isOnboarding,
|
||||
}: LMStudioModalInternalsProps) {
|
||||
const formikProps = useFormikContext<LMStudioModalValues>();
|
||||
const initialApiKey = existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
|
||||
fetchModels(
|
||||
LLMProviderName.LM_STUDIO,
|
||||
{
|
||||
api_base: apiBase,
|
||||
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
|
||||
api_key_changed: apiKey !== initialApiKey,
|
||||
name: existingLlmProvider?.name,
|
||||
},
|
||||
signal
|
||||
).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
formikProps.setFieldValue("model_configurations", []);
|
||||
return;
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", data.models);
|
||||
});
|
||||
},
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[existingLlmProvider?.name, initialApiKey]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
const apiBase = formikProps.values.api_base;
|
||||
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
useEffect(() => {
|
||||
if (apiBase) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(apiBase, apiKey, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
formikProps.setFieldValue("model_configurations", []);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [apiBase, apiKey, debouncedFetchModels]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL for your LM Studio server."
|
||||
placeholder="Your LM Studio API base URL"
|
||||
/>
|
||||
|
||||
<APIKeyField
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
optional
|
||||
subDescription="Optional API key if your LM Studio server requires authentication."
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={false} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LMStudioModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const initialValues: LMStudioModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY: existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY,
|
||||
},
|
||||
} as LMStudioModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.LM_STUDIO}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<LMStudioModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -1,41 +1,32 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchLiteLLMProxyModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:4000";
|
||||
|
||||
@@ -45,30 +36,15 @@ interface LiteLLMProxyModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface LiteLLMProxyModalInternalsProps {
|
||||
formikProps: FormikProps<LiteLLMProxyModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LiteLLMProxyModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: LiteLLMProxyModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
const formikProps = useFormikContext<LiteLLMProxyModalValues>();
|
||||
|
||||
const isFetchDisabled =
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
@@ -77,12 +53,12 @@ function LiteLLMProxyModalInternals({
|
||||
const { models, error } = await fetchLiteLLMProxyModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.LITELLM_PROXY,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -98,58 +74,34 @@ function LiteLLMProxyModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.LITELLM_PROXY}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="The base URL for your LiteLLM Proxy server."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-litellm-proxy.com"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL for your LiteLLM Proxy server."
|
||||
placeholder="https://your-litellm-proxy.com"
|
||||
/>
|
||||
|
||||
<APIKeyField providerName="LiteLLM Proxy" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -159,109 +111,68 @@ export default function LiteLLMProxyModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
LLMProviderName.LITELLM_PROXY
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: LiteLLMProxyModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
} as LiteLLMProxyModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: LiteLLMProxyModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: LLMProviderName.LITELLM_PROXY,
|
||||
provider: LLMProviderName.LITELLM_PROXY,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as LiteLLMProxyModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.LITELLM_PROXY}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LiteLLMProxyModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<LiteLLMProxyModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,47 +1,44 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import * as Yup from "yup";
|
||||
import { Dispatch, SetStateAction, useMemo, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { Card } from "@opal/components";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
|
||||
const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
const TAB_SELF_HOSTED = "self-hosted";
|
||||
const TAB_CLOUD = "cloud";
|
||||
const CLOUD_API_BASE = "https://ollama.com";
|
||||
|
||||
enum Tab {
|
||||
TAB_SELF_HOSTED = "self-hosted",
|
||||
TAB_CLOUD = "cloud",
|
||||
}
|
||||
|
||||
interface OllamaModalValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
@@ -51,104 +48,65 @@ interface OllamaModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface OllamaModalInternalsProps {
|
||||
formikProps: FormikProps<OllamaModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
tab: Tab;
|
||||
setTab: Dispatch<SetStateAction<Tab>>;
|
||||
}
|
||||
|
||||
function OllamaModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
tab,
|
||||
setTab,
|
||||
}: OllamaModalInternalsProps) {
|
||||
const isInitialMount = useRef(true);
|
||||
const formikProps = useFormikContext<OllamaModalValues>();
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, signal: AbortSignal) => {
|
||||
fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
signal,
|
||||
}).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setFetchedModels(data.models);
|
||||
const isFetchDisabled = useMemo(
|
||||
() =>
|
||||
tab === Tab.TAB_SELF_HOSTED
|
||||
? !formikProps.values.api_base
|
||||
: !formikProps.values.custom_config.OLLAMA_API_KEY,
|
||||
[tab, formikProps]
|
||||
);
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
// Only Ollama cloud accepts API key
|
||||
const apiBase = formikProps.values.custom_config?.OLLAMA_API_KEY
|
||||
? CLOUD_API_BASE
|
||||
: formikProps.values.api_base;
|
||||
const { models, error } = await fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useOnMount(() => {
|
||||
if (existingLlmProvider) {
|
||||
handleFetchModels().catch((err) => {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, setFetchedModels]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
// Skip the initial fetch for new providers — api_base starts with a default
|
||||
// value, which would otherwise trigger a fetch before the user has done
|
||||
// anything. Existing providers should still auto-fetch on mount.
|
||||
useEffect(() => {
|
||||
if (isInitialMount.current) {
|
||||
isInitialMount.current = false;
|
||||
if (!existingLlmProvider) return;
|
||||
}
|
||||
|
||||
if (formikProps.values.api_base) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(formikProps.values.api_base, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
debouncedFetchModels,
|
||||
setFetchedModels,
|
||||
existingLlmProvider,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
const hasApiKey = !!formikProps.values.custom_config?.OLLAMA_API_KEY;
|
||||
const defaultTab =
|
||||
existingLlmProvider && hasApiKey ? TAB_CLOUD : TAB_SELF_HOSTED;
|
||||
});
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OLLAMA_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<>
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
<Tabs defaultValue={defaultTab}>
|
||||
<Tabs value={tab} onValueChange={(value) => setTab(value as Tab)}>
|
||||
<Tabs.List>
|
||||
<Tabs.Trigger value={TAB_SELF_HOSTED}>
|
||||
<Tabs.Trigger value={Tab.TAB_SELF_HOSTED}>
|
||||
Self-hosted Ollama
|
||||
</Tabs.Trigger>
|
||||
<Tabs.Trigger value={TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
|
||||
<Tabs.Trigger value={Tab.TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
|
||||
</Tabs.List>
|
||||
<Tabs.Content value={TAB_SELF_HOSTED}>
|
||||
<Tabs.Content value={Tab.TAB_SELF_HOSTED} padding={0}>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
@@ -161,7 +119,7 @@ function OllamaModalInternals({
|
||||
</InputLayouts.Vertical>
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={TAB_CLOUD}>
|
||||
<Tabs.Content value={Tab.TAB_CLOUD}>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.OLLAMA_API_KEY"
|
||||
title="API Key"
|
||||
@@ -178,31 +136,24 @@ function OllamaModalInternals({
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. llama3.1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -212,67 +163,55 @@ export default function OllamaModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } =
|
||||
useWellKnownLLMProvider(OLLAMA_PROVIDER_NAME);
|
||||
const apiKey = existingLlmProvider?.custom_config?.OLLAMA_API_KEY;
|
||||
const defaultTab =
|
||||
existingLlmProvider && !!apiKey ? Tab.TAB_CLOUD : Tab.TAB_SELF_HOSTED;
|
||||
const [tab, setTab] = useState<Tab>(defaultTab);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: OllamaModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY: apiKey,
|
||||
},
|
||||
} as OllamaModalValues;
|
||||
|
||||
const validationSchema = useMemo(
|
||||
() =>
|
||||
buildValidationSchema(isOnboarding, {
|
||||
apiBase: tab === Tab.TAB_SELF_HOSTED,
|
||||
extra:
|
||||
tab === Tab.TAB_CLOUD
|
||||
? {
|
||||
custom_config: Yup.object({
|
||||
OLLAMA_API_KEY: Yup.string().required("API Key is required"),
|
||||
}),
|
||||
}
|
||||
: undefined,
|
||||
}),
|
||||
[tab, isOnboarding]
|
||||
);
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: OllamaModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OLLAMA_PROVIDER_NAME,
|
||||
provider: OLLAMA_PROVIDER_NAME,
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY: "",
|
||||
},
|
||||
} as OllamaModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY:
|
||||
(existingLlmProvider?.custom_config?.OLLAMA_API_KEY as string) ??
|
||||
"",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OLLAMA_CHAT}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
@@ -285,50 +224,39 @@ export default function OllamaModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OLLAMA_CHAT,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<OllamaModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<OllamaModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
tab={tab}
|
||||
setTab={setTab}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
177
web/src/sections/modals/llmConfig/OpenAICompatibleModal.tsx
Normal file
177
web/src/sections/modals/llmConfig/OpenAICompatibleModal.tsx
Normal file
@@ -0,0 +1,177 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchOpenAICompatibleModels } from "@/app/admin/configuration/llm/utils";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIBaseField,
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface OpenAICompatibleModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface OpenAICompatibleModalInternalsProps {
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function OpenAICompatibleModalInternals({
|
||||
existingLlmProvider,
|
||||
isOnboarding,
|
||||
}: OpenAICompatibleModalInternalsProps) {
|
||||
const formikProps = useFormikContext<OpenAICompatibleModalValues>();
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
const { models, error } = await fetchOpenAICompatibleModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useEffect(() => {
|
||||
if (existingLlmProvider && !isFetchDisabled) {
|
||||
handleFetchModels().catch((err) => {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL of your OpenAI-compatible server."
|
||||
placeholder="http://localhost:8000/v1"
|
||||
/>
|
||||
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription={markdown(
|
||||
"Provide an API key if your server requires authentication."
|
||||
)}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function OpenAICompatibleModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
existingLlmProvider
|
||||
) as OpenAICompatibleModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENAI_COMPATIBLE}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENAI_COMPATIBLE,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<OpenAICompatibleModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -1,33 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
|
||||
const OPENAI_PROVIDER_NAME = "openai";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export default function OpenAIModal({
|
||||
variant = "llm-configuration",
|
||||
@@ -35,141 +25,78 @@ export default function OpenAIModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } =
|
||||
useWellKnownLLMProvider(OPENAI_PROVIDER_NAME);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENAI,
|
||||
existingLlmProvider
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues = isOnboarding
|
||||
? {
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OPENAI_PROVIDER_NAME,
|
||||
provider: OPENAI_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENAI}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OPENAI_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OPENAI_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENAI,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OPENAI_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<APIKeyField providerName="OpenAI" />
|
||||
<APIKeyField providerName="OpenAI" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-5.2" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,73 +1,50 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchOpenRouterModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const OPENROUTER_PROVIDER_NAME = "openrouter";
|
||||
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
|
||||
|
||||
interface OpenRouterModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface OpenRouterModalInternalsProps {
|
||||
formikProps: FormikProps<OpenRouterModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function OpenRouterModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: OpenRouterModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
const formikProps = useFormikContext<OpenRouterModalValues>();
|
||||
|
||||
const isFetchDisabled =
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
@@ -76,12 +53,12 @@ function OpenRouterModalInternals({
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.OPENROUTER,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -97,58 +74,34 @@ function OpenRouterModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OPENROUTER_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="Your OpenRouter base URL"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
|
||||
placeholder="Your OpenRouter base URL"
|
||||
/>
|
||||
|
||||
<APIKeyField providerName="OpenRouter" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. openai/gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -158,109 +111,68 @@ export default function OpenRouterModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
OPENROUTER_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: OpenRouterModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENROUTER,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
} as OpenRouterModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: OpenRouterModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OPENROUTER_PROVIDER_NAME,
|
||||
provider: OPENROUTER_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as OpenRouterModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENROUTER}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OPENROUTER_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OPENROUTER_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENROUTER,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<OpenRouterModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<OpenRouterModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,38 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { FileUploadFormField } from "@/components/Field";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
|
||||
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
|
||||
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
|
||||
const VERTEXAI_DEFAULT_LOCATION = "global";
|
||||
|
||||
interface VertexAIModalValues extends BaseLLMFormValues {
|
||||
@@ -48,87 +37,50 @@ export default function VertexAIModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
VERTEXAI_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: VertexAIModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
existingLlmProvider
|
||||
),
|
||||
custom_config: {
|
||||
vertex_credentials:
|
||||
(existingLlmProvider?.custom_config?.vertex_credentials as string) ??
|
||||
"",
|
||||
vertex_location:
|
||||
(existingLlmProvider?.custom_config?.vertex_location as string) ??
|
||||
VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
} as VertexAIModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
extra: {
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
},
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: VertexAIModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: VERTEXAI_PROVIDER_NAME,
|
||||
provider: VERTEXAI_PROVIDER_NAME,
|
||||
default_model_name: VERTEXAI_DEFAULT_MODEL,
|
||||
custom_config: {
|
||||
vertex_credentials: "",
|
||||
vertex_location: VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
} as VertexAIModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
custom_config: {
|
||||
vertex_credentials:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.vertex_credentials as string) ?? "",
|
||||
vertex_location:
|
||||
(existingLlmProvider?.custom_config?.vertex_location as string) ??
|
||||
VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.VERTEX_AI}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(
|
||||
([key, v]) => key === "vertex_credentials" || v !== ""
|
||||
@@ -143,101 +95,75 @@ export default function VertexAIModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: VERTEXAI_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === VERTEXAI_DEFAULT_MODEL,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: VERTEXAI_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.VERTEX_AI,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={VERTEXAI_PROVIDER_NAME}
|
||||
providerName={VERTEXAI_DISPLAY_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_location"
|
||||
title="Google Cloud Region Name"
|
||||
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_location"
|
||||
title="Google Cloud Region Name"
|
||||
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="custom_config.vertex_location"
|
||||
placeholder={VERTEXAI_DEFAULT_LOCATION}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<InputTypeInField
|
||||
name="custom_config.vertex_location"
|
||||
placeholder={VERTEXAI_DEFAULT_LOCATION}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_credentials"
|
||||
title="API Key"
|
||||
subDescription="Attach your API key JSON from Google Cloud to access your models."
|
||||
>
|
||||
<FileUploadFormField
|
||||
name="custom_config.vertex_credentials"
|
||||
label=""
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_credentials"
|
||||
title="API Key"
|
||||
subDescription="Attach your API key JSON from Google Cloud to access your models."
|
||||
>
|
||||
<FileUploadFormField
|
||||
name="custom_config.vertex_credentials"
|
||||
label=""
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{!isOnboarding && (
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gemini-2.5-pro" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && <ModelsAccessField formikProps={formikProps} />}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
@@ -54,11 +55,13 @@ export function getModalForExistingProvider(
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...props} />;
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioForm {...props} />;
|
||||
return <LMStudioModal {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
|
||||
@@ -1,30 +1,31 @@
|
||||
"use client";
|
||||
|
||||
import { ReactNode } from "react";
|
||||
import { Form, FormikProps } from "formik";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { Formik, Form, useFormikContext } from "formik";
|
||||
import type { FormikConfig } from "formik";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
|
||||
import { LLMProviderView, ModelConfiguration } from "@/interfaces/llm";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import Switch from "@/refresh-components/inputs/Switch";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button, LineItemButton, Tag } from "@opal/components";
|
||||
import { Button, LineItemButton } from "@opal/components";
|
||||
import { BaseLLMFormValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { WithoutStyles } from "@opal/types";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Disabled, Hoverable } from "@opal/core";
|
||||
import { Content } from "@opal/layouts";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgOnyxOctagon,
|
||||
SvgOrganization,
|
||||
SvgPlusCircle,
|
||||
SvgRefreshCw,
|
||||
SvgSparkle,
|
||||
SvgUserManage,
|
||||
@@ -46,27 +47,14 @@ import {
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
|
||||
export function FieldSeparator() {
|
||||
return <Separator noPadding className="px-2" />;
|
||||
}
|
||||
|
||||
export type FieldWrapperProps = WithoutStyles<
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>;
|
||||
|
||||
export function FieldWrapper(props: FieldWrapperProps) {
|
||||
return <div {...props} className="p-2 w-full" />;
|
||||
}
|
||||
|
||||
// ─── DisplayNameField ────────────────────────────────────────────────────────
|
||||
|
||||
export interface DisplayNameFieldProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
return (
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="name"
|
||||
title="Display Name"
|
||||
@@ -78,56 +66,68 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
variant={disabled ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── APIKeyField ─────────────────────────────────────────────────────────────
|
||||
|
||||
export interface APIKeyFieldProps {
|
||||
/** Formik field name. @default "api_key" */
|
||||
name?: string;
|
||||
optional?: boolean;
|
||||
providerName?: string;
|
||||
subDescription?: string | RichStr;
|
||||
}
|
||||
|
||||
export function APIKeyField({
|
||||
name = "api_key",
|
||||
optional = false,
|
||||
providerName,
|
||||
subDescription,
|
||||
}: APIKeyFieldProps) {
|
||||
return (
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
name={name}
|
||||
title="API Key"
|
||||
subDescription={
|
||||
providerName
|
||||
? `Paste your API key from ${providerName} to access your models.`
|
||||
: "Paste your API key to access your models."
|
||||
subDescription
|
||||
? subDescription
|
||||
: providerName
|
||||
? `Paste your API key from ${providerName} to access your models.`
|
||||
: "Paste your API key to access your models."
|
||||
}
|
||||
optional={optional}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
<PasswordInputTypeInField name={name} />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── SingleDefaultModelField ─────────────────────────────────────────────────
|
||||
// ─── APIBaseField ───────────────────────────────────────────────────────────
|
||||
|
||||
export interface SingleDefaultModelFieldProps {
|
||||
export interface APIBaseFieldProps {
|
||||
optional?: boolean;
|
||||
subDescription?: string | RichStr;
|
||||
placeholder?: string;
|
||||
}
|
||||
|
||||
export function SingleDefaultModelField({
|
||||
placeholder = "E.g. gpt-4o",
|
||||
}: SingleDefaultModelFieldProps) {
|
||||
export function APIBaseField({
|
||||
optional = false,
|
||||
subDescription,
|
||||
placeholder = "https://",
|
||||
}: APIBaseFieldProps) {
|
||||
return (
|
||||
<InputLayouts.Vertical
|
||||
name="default_model_name"
|
||||
title="Default Model"
|
||||
description="The model to use by default for this provider unless otherwise specified."
|
||||
>
|
||||
<InputTypeInField name="default_model_name" placeholder={placeholder} />
|
||||
</InputLayouts.Vertical>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription={subDescription}
|
||||
optional={optional}
|
||||
>
|
||||
<InputTypeInField name="api_base" placeholder={placeholder} />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -137,13 +137,8 @@ export function SingleDefaultModelField({
|
||||
const GROUP_PREFIX = "group:";
|
||||
const AGENT_PREFIX = "agent:";
|
||||
|
||||
interface ModelsAccessFieldProps<T> {
|
||||
formikProps: FormikProps<T>;
|
||||
}
|
||||
|
||||
export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
formikProps,
|
||||
}: ModelsAccessFieldProps<T>) {
|
||||
export function ModelAccessField() {
|
||||
const formikProps = useFormikContext<BaseLLMFormValues>();
|
||||
const { agents } = useAgents();
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
const { data: usersData } = useUsers({ includeApiKeys: false });
|
||||
@@ -229,7 +224,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full">
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Horizontal
|
||||
name="is_public"
|
||||
title="Models Access"
|
||||
@@ -250,7 +245,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</InputLayouts.Horizontal>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
{!isPublic && (
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
@@ -316,7 +311,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
|
||||
{selectedAgentIds.length > 0 ? (
|
||||
<div className="grid grid-cols-2 gap-1 w-full">
|
||||
@@ -371,84 +366,73 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
|
||||
// ─── ModelsField ─────────────────────────────────────────────────────
|
||||
|
||||
export interface ModelsFieldProps<T> {
|
||||
formikProps: FormikProps<T>;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
recommendedDefaultModel: SimpleKnownModel | null;
|
||||
export interface ModelSelectionFieldProps {
|
||||
shouldShowAutoUpdateToggle: boolean;
|
||||
/** Called when the user clicks the refresh button to re-fetch models. */
|
||||
onRefetch?: () => Promise<void> | void;
|
||||
/** Called when the user adds a custom model name (e.g. for Azure). */
|
||||
onAddModel?: (modelName: string) => void;
|
||||
}
|
||||
|
||||
export function ModelsField<T extends BaseLLMFormValues>({
|
||||
formikProps,
|
||||
modelConfigurations,
|
||||
recommendedDefaultModel,
|
||||
export function ModelSelectionField({
|
||||
shouldShowAutoUpdateToggle,
|
||||
onRefetch,
|
||||
}: ModelsFieldProps<T>) {
|
||||
const isAutoMode = formikProps.values.is_auto_mode;
|
||||
const selectedModels = formikProps.values.selected_model_names ?? [];
|
||||
const defaultModel = formikProps.values.default_model_name;
|
||||
onAddModel,
|
||||
}: ModelSelectionFieldProps) {
|
||||
const formikProps = useFormikContext<BaseLLMFormValues>();
|
||||
const [newModelName, setNewModelName] = useState("");
|
||||
// When the auto-update toggle is hidden, auto mode should have no effect —
|
||||
// otherwise models can't be deselected and "Select All" stays disabled.
|
||||
const isAutoMode =
|
||||
shouldShowAutoUpdateToggle && formikProps.values.is_auto_mode;
|
||||
const models = formikProps.values.model_configurations;
|
||||
|
||||
function handleCheckboxChange(modelName: string, checked: boolean) {
|
||||
// Read current values inside the handler to avoid stale closure issues
|
||||
const currentSelected = formikProps.values.selected_model_names ?? [];
|
||||
const currentDefault = formikProps.values.default_model_name;
|
||||
|
||||
if (checked) {
|
||||
const newSelected = [...currentSelected, modelName];
|
||||
formikProps.setFieldValue("selected_model_names", newSelected);
|
||||
// If this is the first model, set it as default
|
||||
if (currentSelected.length === 0) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
}
|
||||
} else {
|
||||
const newSelected = currentSelected.filter((name) => name !== modelName);
|
||||
formikProps.setFieldValue("selected_model_names", newSelected);
|
||||
// If removing the default, set the first remaining model as default
|
||||
if (currentDefault === modelName && newSelected.length > 0) {
|
||||
formikProps.setFieldValue("default_model_name", newSelected[0]);
|
||||
} else if (newSelected.length === 0) {
|
||||
formikProps.setFieldValue("default_model_name", undefined);
|
||||
}
|
||||
// Snapshot the original model visibility so we can restore it when
|
||||
// toggling auto mode back on.
|
||||
const originalModelsRef = useRef(models);
|
||||
useEffect(() => {
|
||||
if (originalModelsRef.current.length === 0 && models.length > 0) {
|
||||
originalModelsRef.current = models;
|
||||
}
|
||||
}
|
||||
}, [models]);
|
||||
|
||||
function handleSetDefault(modelName: string) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
// Automatically derive test_model_name from model_configurations.
|
||||
// Any change to visibility or the model list syncs this automatically.
|
||||
useEffect(() => {
|
||||
const firstVisible = models.find((m) => m.is_visible)?.name;
|
||||
if (firstVisible !== formikProps.values.test_model_name) {
|
||||
formikProps.setFieldValue("test_model_name", firstVisible);
|
||||
}
|
||||
}, [models]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
function setVisibility(modelName: string, visible: boolean) {
|
||||
const updated = models.map((m) =>
|
||||
m.name === modelName ? { ...m, is_visible: visible } : m
|
||||
);
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
}
|
||||
|
||||
function handleToggleAutoMode(nextIsAutoMode: boolean) {
|
||||
formikProps.setFieldValue("is_auto_mode", nextIsAutoMode);
|
||||
formikProps.setFieldValue(
|
||||
"selected_model_names",
|
||||
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
|
||||
);
|
||||
formikProps.setFieldValue(
|
||||
"default_model_name",
|
||||
recommendedDefaultModel?.name ?? undefined
|
||||
);
|
||||
}
|
||||
|
||||
const allSelected =
|
||||
modelConfigurations.length > 0 &&
|
||||
modelConfigurations.every((m) => selectedModels.includes(m.name));
|
||||
|
||||
function handleToggleSelectAll() {
|
||||
if (allSelected) {
|
||||
formikProps.setFieldValue("selected_model_names", []);
|
||||
formikProps.setFieldValue("default_model_name", undefined);
|
||||
} else {
|
||||
const allNames = modelConfigurations.map((m) => m.name);
|
||||
formikProps.setFieldValue("selected_model_names", allNames);
|
||||
if (!formikProps.values.default_model_name && allNames.length > 0) {
|
||||
formikProps.setFieldValue("default_model_name", allNames[0]);
|
||||
}
|
||||
if (nextIsAutoMode) {
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
originalModelsRef.current
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
|
||||
const allSelected = models.length > 0 && models.every((m) => m.is_visible);
|
||||
|
||||
function handleToggleSelectAll() {
|
||||
const nextVisible = !allSelected;
|
||||
const updated = models.map((m) => ({
|
||||
...m,
|
||||
is_visible: nextVisible,
|
||||
}));
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
}
|
||||
|
||||
const visibleModels = models.filter((m) => m.is_visible);
|
||||
|
||||
return (
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
@@ -460,15 +444,14 @@ export function ModelsField<T extends BaseLLMFormValues>({
|
||||
center
|
||||
>
|
||||
<Section flexDirection="row" gap={0}>
|
||||
<Disabled disabled={isAutoMode || modelConfigurations.length === 0}>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
onClick={handleToggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Unselect All" : "Select All"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isAutoMode || models.length === 0}
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
onClick={handleToggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Unselect All" : "Select All"}
|
||||
</Button>
|
||||
{onRefetch && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
@@ -489,91 +472,75 @@ export function ModelsField<T extends BaseLLMFormValues>({
|
||||
</Section>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
{modelConfigurations.length === 0 ? (
|
||||
{models.length === 0 ? (
|
||||
<EmptyMessageCard title="No models available." />
|
||||
) : (
|
||||
<Section gap={0.25}>
|
||||
{isAutoMode
|
||||
? // Auto mode: read-only display
|
||||
visibleModels.map((model) => (
|
||||
<Hoverable.Root
|
||||
? visibleModels.map((model) => (
|
||||
<LineItemButton
|
||||
key={model.name}
|
||||
group="LLMConfigurationButton"
|
||||
widthVariant="full"
|
||||
>
|
||||
<LineItemButton
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state="selected"
|
||||
icon={() => <Checkbox checked />}
|
||||
title={model.display_name || model.name}
|
||||
rightChildren={
|
||||
model.name === defaultModel ? (
|
||||
<Section>
|
||||
<Tag title="Default Model" color="blue" />
|
||||
</Section>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</Hoverable.Root>
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state="selected"
|
||||
icon={() => <Checkbox checked />}
|
||||
title={model.display_name || model.name}
|
||||
/>
|
||||
))
|
||||
: // Manual mode: checkbox selection
|
||||
modelConfigurations.map((modelConfiguration) => {
|
||||
const isSelected = selectedModels.includes(
|
||||
modelConfiguration.name
|
||||
);
|
||||
const isDefault = defaultModel === modelConfiguration.name;
|
||||
: models.map((model) => (
|
||||
<LineItemButton
|
||||
key={model.name}
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state={model.is_visible ? "selected" : "empty"}
|
||||
icon={() => <Checkbox checked={model.is_visible} />}
|
||||
title={model.name}
|
||||
onClick={() => setVisibility(model.name, !model.is_visible)}
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
)}
|
||||
|
||||
return (
|
||||
<Hoverable.Root
|
||||
key={modelConfiguration.name}
|
||||
group="LLMConfigurationButton"
|
||||
widthVariant="full"
|
||||
>
|
||||
<LineItemButton
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state={isSelected ? "selected" : "empty"}
|
||||
icon={() => <Checkbox checked={isSelected} />}
|
||||
title={modelConfiguration.name}
|
||||
onClick={() =>
|
||||
handleCheckboxChange(
|
||||
modelConfiguration.name,
|
||||
!isSelected
|
||||
)
|
||||
}
|
||||
rightChildren={
|
||||
isSelected ? (
|
||||
isDefault ? (
|
||||
<Section>
|
||||
<Tag color="blue" title="Default Model" />
|
||||
</Section>
|
||||
) : (
|
||||
<Hoverable.Item
|
||||
group="LLMConfigurationButton"
|
||||
variant="opacity-on-hover"
|
||||
>
|
||||
<Button
|
||||
size="sm"
|
||||
prominence="internal"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleSetDefault(modelConfiguration.name);
|
||||
}}
|
||||
type="button"
|
||||
>
|
||||
Set as default
|
||||
</Button>
|
||||
</Hoverable.Item>
|
||||
)
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</Hoverable.Root>
|
||||
);
|
||||
})}
|
||||
{onAddModel && !isAutoMode && (
|
||||
<Section flexDirection="row" gap={0.5}>
|
||||
<div className="flex-1">
|
||||
<InputTypeIn
|
||||
placeholder="Enter model name"
|
||||
value={newModelName}
|
||||
onChange={(e) => setNewModelName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && newModelName.trim()) {
|
||||
e.preventDefault();
|
||||
const trimmed = newModelName.trim();
|
||||
if (!models.some((m) => m.name === trimmed)) {
|
||||
onAddModel(trimmed);
|
||||
setNewModelName("");
|
||||
}
|
||||
}
|
||||
}}
|
||||
showClearButton={false}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
icon={SvgPlusCircle}
|
||||
type="button"
|
||||
disabled={
|
||||
!newModelName.trim() ||
|
||||
models.some((m) => m.name === newModelName.trim())
|
||||
}
|
||||
onClick={() => {
|
||||
const trimmed = newModelName.trim();
|
||||
if (trimmed && !models.some((m) => m.name === trimmed)) {
|
||||
onAddModel(trimmed);
|
||||
setNewModelName("");
|
||||
}
|
||||
}}
|
||||
>
|
||||
Add Model
|
||||
</Button>
|
||||
</Section>
|
||||
)}
|
||||
|
||||
@@ -593,41 +560,96 @@ export function ModelsField<T extends BaseLLMFormValues>({
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLMConfigurationModalWrapper
|
||||
// ============================================================================
|
||||
// ─── ModalWrapper ─────────────────────────────────────────────────────
|
||||
|
||||
interface LLMConfigurationModalWrapperProps {
|
||||
providerEndpoint: string;
|
||||
providerName?: string;
|
||||
existingProviderName?: string;
|
||||
export interface ModalWrapperProps<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
llmProvider?: LLMProviderView;
|
||||
onClose: () => void;
|
||||
isFormValid: boolean;
|
||||
isDirty?: boolean;
|
||||
isTesting?: boolean;
|
||||
isSubmitting?: boolean;
|
||||
children: ReactNode;
|
||||
initialValues: T;
|
||||
validationSchema: FormikConfig<T>["validationSchema"];
|
||||
onSubmit: FormikConfig<T>["onSubmit"];
|
||||
children: React.ReactNode;
|
||||
}
|
||||
export function ModalWrapper<T extends BaseLLMFormValues = BaseLLMFormValues>({
|
||||
providerName,
|
||||
llmProvider,
|
||||
onClose,
|
||||
initialValues,
|
||||
validationSchema,
|
||||
onSubmit,
|
||||
children,
|
||||
}: ModalWrapperProps<T>) {
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount
|
||||
onSubmit={onSubmit}
|
||||
>
|
||||
{() => (
|
||||
<ModalWrapperInner
|
||||
providerName={providerName}
|
||||
llmProvider={llmProvider}
|
||||
onClose={onClose}
|
||||
modelConfigurations={initialValues.model_configurations}
|
||||
>
|
||||
{children}
|
||||
</ModalWrapperInner>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
export function LLMConfigurationModalWrapper({
|
||||
providerEndpoint,
|
||||
interface ModalWrapperInnerProps {
|
||||
providerName: string;
|
||||
llmProvider?: LLMProviderView;
|
||||
onClose: () => void;
|
||||
modelConfigurations?: ModelConfiguration[];
|
||||
children: React.ReactNode;
|
||||
}
|
||||
function ModalWrapperInner({
|
||||
providerName,
|
||||
existingProviderName,
|
||||
llmProvider,
|
||||
onClose,
|
||||
isFormValid,
|
||||
isDirty,
|
||||
isTesting,
|
||||
isSubmitting,
|
||||
modelConfigurations,
|
||||
children,
|
||||
}: LLMConfigurationModalWrapperProps) {
|
||||
const busy = isTesting || isSubmitting;
|
||||
const providerIcon = getProviderIcon(providerEndpoint);
|
||||
const providerDisplayName =
|
||||
providerName ?? getProviderDisplayName(providerEndpoint);
|
||||
const providerProductName = getProviderProductName(providerEndpoint);
|
||||
}: ModalWrapperInnerProps) {
|
||||
const { isValid, dirty, isSubmitting, status, setFieldValue, values } =
|
||||
useFormikContext<BaseLLMFormValues>();
|
||||
|
||||
const title = existingProviderName
|
||||
? `Configure "${existingProviderName}"`
|
||||
// When SWR resolves after mount, populate model_configurations if still
|
||||
// empty. test_model_name is then derived automatically by
|
||||
// ModelSelectionField's useEffect.
|
||||
useEffect(() => {
|
||||
if (
|
||||
modelConfigurations &&
|
||||
modelConfigurations.length > 0 &&
|
||||
values.model_configurations.length === 0
|
||||
) {
|
||||
setFieldValue("model_configurations", modelConfigurations);
|
||||
}
|
||||
}, [modelConfigurations]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const isTesting = status?.isTesting === true;
|
||||
const busy = isTesting || isSubmitting;
|
||||
|
||||
const disabledTooltip = busy
|
||||
? undefined
|
||||
: !isValid
|
||||
? "Please fill in all required fields."
|
||||
: !dirty
|
||||
? "No changes to save."
|
||||
: undefined;
|
||||
|
||||
const providerIcon = getProviderIcon(providerName);
|
||||
const providerDisplayName = getProviderDisplayName(providerName);
|
||||
const providerProductName = getProviderProductName(providerName);
|
||||
|
||||
const title = llmProvider
|
||||
? `Configure "${llmProvider.name}"`
|
||||
: `Set up ${providerProductName}`;
|
||||
const description = `Connect to ${providerDisplayName} and set up your ${providerProductName} models.`;
|
||||
|
||||
@@ -650,21 +672,20 @@ export function LLMConfigurationModalWrapper({
|
||||
<Button prominence="secondary" onClick={onClose} type="button">
|
||||
Cancel
|
||||
</Button>
|
||||
<Disabled
|
||||
disabled={
|
||||
!isFormValid || busy || (!!existingProviderName && !isDirty)
|
||||
}
|
||||
<Button
|
||||
disabled={!isValid || !dirty || busy}
|
||||
type="submit"
|
||||
icon={busy ? SimpleLoader : undefined}
|
||||
tooltip={disabledTooltip}
|
||||
>
|
||||
<Button type="submit" icon={busy ? SimpleLoader : undefined}>
|
||||
{existingProviderName
|
||||
? busy
|
||||
? "Updating"
|
||||
: "Update"
|
||||
: busy
|
||||
? "Connecting"
|
||||
: "Connect"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
{llmProvider?.name
|
||||
? busy
|
||||
? "Updating"
|
||||
: "Update"
|
||||
: busy
|
||||
? "Connecting"
|
||||
: "Connect"}
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Form>
|
||||
</Modal.Content>
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { parseAzureTargetUri } from "@/lib/azureTargetUri";
|
||||
@@ -18,13 +13,11 @@ import {
|
||||
} from "@/lib/analytics";
|
||||
import {
|
||||
BaseLLMFormValues,
|
||||
SubmitLLMProviderParams,
|
||||
SubmitOnboardingProviderParams,
|
||||
TestApiKeyResult,
|
||||
filterModelConfigurations,
|
||||
getAutoModeModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
|
||||
// ─── Test helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
const submitLlmTestRequest = async (
|
||||
payload: Record<string, unknown>,
|
||||
fallbackErrorMessage: string
|
||||
@@ -50,161 +43,6 @@ const submitLlmTestRequest = async (
|
||||
}
|
||||
};
|
||||
|
||||
export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
providerName,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
hideSuccess,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
}: SubmitLLMProviderParams<T>): Promise<void> => {
|
||||
setSubmitting(true);
|
||||
|
||||
const { selected_model_names: visibleModels, api_key, ...rest } = values;
|
||||
|
||||
// In auto mode, use recommended models from descriptor
|
||||
// In manual mode, use user's selection
|
||||
let filteredModelConfigurations: ModelConfiguration[];
|
||||
let finalDefaultModelName = rest.default_model_name;
|
||||
|
||||
if (values.is_auto_mode) {
|
||||
filteredModelConfigurations =
|
||||
getAutoModeModelConfigurations(modelConfigurations);
|
||||
|
||||
// In auto mode, use the first recommended model as default if current default isn't in the list
|
||||
const visibleModelNames = new Set(
|
||||
filteredModelConfigurations.map((m) => m.name)
|
||||
);
|
||||
if (
|
||||
finalDefaultModelName &&
|
||||
!visibleModelNames.has(finalDefaultModelName)
|
||||
) {
|
||||
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
|
||||
}
|
||||
} else {
|
||||
filteredModelConfigurations = filterModelConfigurations(
|
||||
modelConfigurations,
|
||||
visibleModels,
|
||||
rest.default_model_name as string | undefined
|
||||
);
|
||||
}
|
||||
|
||||
const customConfigChanged = !isEqual(
|
||||
values.custom_config,
|
||||
initialValues.custom_config
|
||||
);
|
||||
|
||||
const normalizedApiBase =
|
||||
typeof rest.api_base === "string" && rest.api_base.trim() === ""
|
||||
? undefined
|
||||
: rest.api_base;
|
||||
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: normalizedApiBase,
|
||||
default_model_name: finalDefaultModelName,
|
||||
api_key,
|
||||
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
|
||||
custom_config_changed: customConfigChanged,
|
||||
model_configurations: filteredModelConfigurations,
|
||||
};
|
||||
|
||||
// Test the configuration
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setIsTesting(true);
|
||||
|
||||
const response = await fetch("/api/admin/llm/test", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: finalDefaultModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
});
|
||||
setIsTesting(false);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(errorMsg);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
}`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
const fullErrorMsg = existingLlmProvider
|
||||
? `Failed to update provider: ${errorMsg}`
|
||||
: `Failed to enable provider: ${errorMsg}`;
|
||||
toast.error(fullErrorMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (shouldMarkAsDefault) {
|
||||
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: finalDefaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const errorMsg = (await setDefaultResponse.json()).detail;
|
||||
toast.error(`Failed to set provider as default: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
onClose();
|
||||
|
||||
if (!hideSuccess) {
|
||||
const successMsg = existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!";
|
||||
toast.success(successMsg);
|
||||
}
|
||||
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
});
|
||||
|
||||
setSubmitting(false);
|
||||
};
|
||||
|
||||
export const testApiKeyHelper = async (
|
||||
providerName: string,
|
||||
formValues: Record<string, unknown>,
|
||||
@@ -241,7 +79,7 @@ export const testApiKeyHelper = async (
|
||||
...((formValues?.custom_config as Record<string, unknown>) ?? {}),
|
||||
...(customConfigOverride ?? {}),
|
||||
},
|
||||
model: modelName ?? (formValues?.default_model_name as string) ?? "",
|
||||
model: modelName ?? (formValues?.test_model_name as string) ?? "",
|
||||
};
|
||||
|
||||
return await submitLlmTestRequest(
|
||||
@@ -259,96 +97,148 @@ export const testCustomProvider = async (
|
||||
);
|
||||
};
|
||||
|
||||
export const submitOnboardingProvider = async ({
|
||||
// ─── Submit provider ──────────────────────────────────────────────────────
|
||||
|
||||
export interface SubmitProviderParams<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
values: T;
|
||||
initialValues: T;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
isCustomProvider?: boolean;
|
||||
setStatus: (status: Record<string, unknown>) => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
onClose: () => void;
|
||||
/** Called after successful create/update + set-default. Use for cache refresh, state updates, toasts, etc. */
|
||||
onSuccess?: () => void | Promise<void>;
|
||||
/** Analytics source for tracking. @default LLMProviderConfiguredSource.ADMIN_PAGE */
|
||||
analyticsSource?: LLMProviderConfiguredSource;
|
||||
}
|
||||
|
||||
export async function submitProvider<T extends BaseLLMFormValues>({
|
||||
providerName,
|
||||
payload,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
isCustomProvider,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
setIsSubmitting,
|
||||
}: SubmitOnboardingProviderParams): Promise<void> => {
|
||||
setIsSubmitting(true);
|
||||
onSuccess,
|
||||
analyticsSource = LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
}: SubmitProviderParams<T>): Promise<void> {
|
||||
setSubmitting(true);
|
||||
|
||||
// Test credentials
|
||||
let result: TestApiKeyResult;
|
||||
if (isCustomProvider) {
|
||||
result = await testCustomProvider(payload);
|
||||
} else {
|
||||
result = await testApiKeyHelper(providerName, payload);
|
||||
const { test_model_name, api_key, ...rest } = values;
|
||||
const testModelName =
|
||||
test_model_name ||
|
||||
values.model_configurations.find((m) => m.is_visible)?.name ||
|
||||
"";
|
||||
|
||||
// ── Test credentials ────────────────────────────────────────────────
|
||||
const customConfigChanged = !isEqual(
|
||||
values.custom_config,
|
||||
initialValues.custom_config
|
||||
);
|
||||
|
||||
const normalizedApiBase =
|
||||
typeof rest.api_base === "string" && rest.api_base.trim() === ""
|
||||
? undefined
|
||||
: rest.api_base;
|
||||
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: normalizedApiBase,
|
||||
api_key,
|
||||
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
|
||||
custom_config_changed: customConfigChanged,
|
||||
};
|
||||
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setStatus({ isTesting: true });
|
||||
|
||||
const testResult = await submitLlmTestRequest(
|
||||
{
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: testModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
},
|
||||
"An error occurred while testing the provider."
|
||||
);
|
||||
setStatus({ isTesting: false });
|
||||
|
||||
if (!testResult.ok) {
|
||||
toast.error(testResult.errorMessage);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.ok) {
|
||||
toast.error(result.errorMessage);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create provider
|
||||
const response = await fetch(`${LLM_PROVIDERS_ADMIN_URL}?is_creation=true`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
// ── Create/update provider ──────────────────────────────────────────
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
}`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(errorMsg);
|
||||
setIsSubmitting(false);
|
||||
const fullErrorMsg = existingLlmProvider
|
||||
? `Failed to update provider: ${errorMsg}`
|
||||
: `Failed to enable provider: ${errorMsg}`;
|
||||
toast.error(fullErrorMsg);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Set as default if first provider
|
||||
if (
|
||||
onboardingState?.data?.llmProviders == null ||
|
||||
onboardingState.data.llmProviders.length === 0
|
||||
) {
|
||||
// ── Set as default ──────────────────────────────────────────────────
|
||||
if (shouldMarkAsDefault && testModelName) {
|
||||
try {
|
||||
const newLlmProvider = await response.json();
|
||||
if (newLlmProvider?.id != null) {
|
||||
const defaultModelName =
|
||||
(payload as Record<string, string>).default_model_name ??
|
||||
(payload as Record<string, ModelConfiguration[]>)
|
||||
.model_configurations?.[0]?.name ??
|
||||
"";
|
||||
|
||||
if (defaultModelName) {
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: defaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const err = await setDefaultResponse.json().catch(() => ({}));
|
||||
toast.error(err?.detail ?? "Failed to set provider as default");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: testModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const err = await setDefaultResponse.json().catch(() => ({}));
|
||||
toast.error(err?.detail ?? "Failed to set provider as default");
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
} catch (_e) {
|
||||
} catch {
|
||||
toast.error("Failed to set new provider as default");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Post-success ────────────────────────────────────────────────────
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: isCustomProvider ? "custom" : providerName,
|
||||
is_creation: true,
|
||||
source: LLMProviderConfiguredSource.CHAT_ONBOARDING,
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: analyticsSource,
|
||||
});
|
||||
|
||||
// Update onboarding state
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
isCustomProvider ? "custom" : providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
if (onSuccess) await onSuccess();
|
||||
|
||||
setIsSubmitting(false);
|
||||
setSubmitting(false);
|
||||
onClose();
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,197 +1,130 @@
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { ScopedMutator } from "swr";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
|
||||
// Common class names for the Form component across all LLM provider forms
|
||||
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
|
||||
// ─── useInitialValues ─────────────────────────────────────────────────────
|
||||
|
||||
export const buildDefaultInitialValues = (
|
||||
/** Builds the merged model list from existing + well-known, deduped by name. */
|
||||
function buildModelConfigurations(
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
modelConfigurations?: ModelConfiguration[],
|
||||
currentDefaultModelName?: string
|
||||
) => {
|
||||
const defaultModelName =
|
||||
(currentDefaultModelName &&
|
||||
existingLlmProvider?.model_configurations?.some(
|
||||
(m) => m.name === currentDefaultModelName
|
||||
)
|
||||
? currentDefaultModelName
|
||||
: undefined) ??
|
||||
existingLlmProvider?.model_configurations?.[0]?.name ??
|
||||
modelConfigurations?.[0]?.name ??
|
||||
"";
|
||||
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
|
||||
): ModelConfiguration[] {
|
||||
const existingModels = existingLlmProvider?.model_configurations ?? [];
|
||||
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
|
||||
|
||||
// Auto mode must be explicitly enabled by the user
|
||||
// Default to false for new providers, preserve existing value when editing
|
||||
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;
|
||||
const modelMap = new Map<string, ModelConfiguration>();
|
||||
wellKnownModels.forEach((m) => modelMap.set(m.name, m));
|
||||
existingModels.forEach((m) => modelMap.set(m.name, m));
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
}
|
||||
|
||||
/** Shared initial values for all LLM provider forms (both onboarding and admin). */
|
||||
export function useInitialValues(
|
||||
isOnboarding: boolean,
|
||||
providerName: LLMProviderName,
|
||||
existingLlmProvider?: LLMProviderView
|
||||
) {
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(providerName);
|
||||
|
||||
const modelConfigurations = buildModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? undefined
|
||||
);
|
||||
|
||||
const testModelName =
|
||||
modelConfigurations.find((m) => m.is_visible)?.name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name;
|
||||
|
||||
return {
|
||||
name: existingLlmProvider?.name || "",
|
||||
default_model_name: defaultModelName,
|
||||
provider: existingLlmProvider?.provider ?? providerName,
|
||||
name: isOnboarding ? providerName : existingLlmProvider?.name ?? "",
|
||||
api_key: existingLlmProvider?.api_key ?? undefined,
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
is_auto_mode: isAutoMode,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
personas: existingLlmProvider?.personas ?? [],
|
||||
selected_model_names: existingLlmProvider
|
||||
? existingLlmProvider.model_configurations
|
||||
.filter((modelConfiguration) => modelConfiguration.is_visible)
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
: modelConfigurations
|
||||
?.filter((modelConfiguration) => modelConfiguration.is_visible)
|
||||
.map((modelConfiguration) => modelConfiguration.name) ?? [],
|
||||
model_configurations: modelConfigurations,
|
||||
test_model_name: testModelName,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
// ─── buildValidationSchema ────────────────────────────────────────────────
|
||||
|
||||
interface ValidationSchemaOptions {
|
||||
apiKey?: boolean;
|
||||
apiBase?: boolean;
|
||||
extra?: Yup.ObjectShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the validation schema for a modal.
|
||||
*
|
||||
* @param isOnboarding — controls the base schema:
|
||||
* - `true`: minimal (only `test_model_name`).
|
||||
* - `false`: full admin schema (display name, access, models, etc.).
|
||||
* @param options.apiKey — require `api_key`.
|
||||
* @param options.apiBase — require `api_base`.
|
||||
* @param options.extra — arbitrary Yup fields for provider-specific validation.
|
||||
*/
|
||||
export function buildValidationSchema(
|
||||
isOnboarding: boolean,
|
||||
{ apiKey, apiBase, extra }: ValidationSchemaOptions = {}
|
||||
) {
|
||||
const providerFields: Yup.ObjectShape = {
|
||||
...(apiKey && {
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
}),
|
||||
...(apiBase && {
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
}),
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (isOnboarding) {
|
||||
return Yup.object().shape({
|
||||
test_model_name: Yup.string().required("Model name is required"),
|
||||
...providerFields,
|
||||
});
|
||||
}
|
||||
|
||||
export const buildDefaultValidationSchema = () => {
|
||||
return Yup.object({
|
||||
name: Yup.string().required("Display Name is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
is_public: Yup.boolean().required(),
|
||||
is_auto_mode: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
personas: Yup.array().of(Yup.number()),
|
||||
selected_model_names: Yup.array().of(Yup.string()),
|
||||
test_model_name: Yup.string().required("Model name is required"),
|
||||
...providerFields,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export const buildAvailableModelConfigurations = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
|
||||
): ModelConfiguration[] => {
|
||||
const existingModels = existingLlmProvider?.model_configurations ?? [];
|
||||
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
|
||||
// ─── Form value types ─────────────────────────────────────────────────────
|
||||
|
||||
// Create a map to deduplicate by model name, preferring existing models
|
||||
const modelMap = new Map<string, ModelConfiguration>();
|
||||
|
||||
// Add well-known models first
|
||||
wellKnownModels.forEach((model) => {
|
||||
modelMap.set(model.name, model);
|
||||
});
|
||||
|
||||
// Override with existing models (they take precedence)
|
||||
existingModels.forEach((model) => {
|
||||
modelMap.set(model.name, model);
|
||||
});
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
};
|
||||
|
||||
// Base form values that all provider forms share
|
||||
/** Base form values that all provider forms share. */
|
||||
export interface BaseLLMFormValues {
|
||||
name: string;
|
||||
api_key?: string;
|
||||
api_base?: string;
|
||||
default_model_name?: string;
|
||||
/** Model name used for the test request — automatically derived. */
|
||||
test_model_name?: string;
|
||||
is_public: boolean;
|
||||
is_auto_mode: boolean;
|
||||
groups: number[];
|
||||
personas: number[];
|
||||
selected_model_names: string[];
|
||||
/** The full model list with is_visible set directly by user interaction. */
|
||||
model_configurations: ModelConfiguration[];
|
||||
custom_config?: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface SubmitLLMProviderParams<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
values: T;
|
||||
initialValues: T;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
hideSuccess?: boolean;
|
||||
setIsTesting: (testing: boolean) => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
|
||||
export const filterModelConfigurations = (
|
||||
currentModelConfigurations: ModelConfiguration[],
|
||||
visibleModels: string[],
|
||||
defaultModelName?: string
|
||||
): ModelConfiguration[] => {
|
||||
return currentModelConfigurations
|
||||
.map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
)
|
||||
.filter(
|
||||
(modelConfiguration) =>
|
||||
modelConfiguration.name === defaultModelName ||
|
||||
modelConfiguration.is_visible
|
||||
);
|
||||
};
|
||||
|
||||
// Helper to get model configurations for auto mode
|
||||
// In auto mode, we include ALL models but preserve their visibility status
|
||||
// Models in the auto config are visible, others are created but not visible
|
||||
export const getAutoModeModelConfigurations = (
|
||||
modelConfigurations: ModelConfiguration[]
|
||||
): ModelConfiguration[] => {
|
||||
return modelConfigurations.map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: modelConfiguration.is_visible,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
);
|
||||
};
|
||||
// ─── Misc ─────────────────────────────────────────────────────────────────
|
||||
|
||||
export type TestApiKeyResult =
|
||||
| { ok: true }
|
||||
| { ok: false; errorMessage: string };
|
||||
|
||||
export const getModelOptions = (
|
||||
fetchedModelConfigurations: Array<{ name: string }>
|
||||
) => {
|
||||
return fetchedModelConfigurations.map((model) => ({
|
||||
label: model.name,
|
||||
value: model.name,
|
||||
}));
|
||||
};
|
||||
|
||||
/** Initial values used by onboarding forms (flat shape, always creating new). */
|
||||
export const buildOnboardingInitialValues = () => ({
|
||||
name: "",
|
||||
provider: "",
|
||||
api_key: "",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
default_model_name: "",
|
||||
model_configurations: [] as ModelConfiguration[],
|
||||
custom_config: {} as Record<string, string>,
|
||||
api_key_changed: true,
|
||||
groups: [] as number[],
|
||||
is_public: true,
|
||||
is_auto_mode: false,
|
||||
personas: [] as number[],
|
||||
selected_model_names: [] as string[],
|
||||
deployment_name: "",
|
||||
target_uri: "",
|
||||
});
|
||||
|
||||
export interface SubmitOnboardingProviderParams {
|
||||
providerName: string;
|
||||
payload: Record<string, unknown>;
|
||||
onboardingState: OnboardingState;
|
||||
onboardingActions: OnboardingActions;
|
||||
isCustomProvider: boolean;
|
||||
onClose: () => void;
|
||||
setIsSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import React from "react";
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderFormProps,
|
||||
} from "@/interfaces/llm";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
@@ -12,8 +13,9 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
// Display info for LLM provider cards - title is the product name, displayName is the company/platform
|
||||
const PROVIDER_DISPLAY_INFO: Record<
|
||||
@@ -47,6 +49,10 @@ const PROVIDER_DISPLAY_INFO: Record<
|
||||
title: "LiteLLM Proxy",
|
||||
displayName: "LiteLLM Proxy",
|
||||
},
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: {
|
||||
title: "OpenAI Compatible",
|
||||
displayName: "OpenAI Compatible",
|
||||
},
|
||||
};
|
||||
|
||||
export function getProviderDisplayInfo(providerName: string): {
|
||||
@@ -78,12 +84,26 @@ export function getOnboardingForm({
|
||||
open,
|
||||
onOpenChange,
|
||||
}: OnboardingFormProps): React.ReactNode {
|
||||
const sharedProps = {
|
||||
const providerName = isCustomProvider
|
||||
? "custom"
|
||||
: llmDescriptor?.name ?? "custom";
|
||||
|
||||
const sharedProps: LLMProviderFormProps = {
|
||||
variant: "onboarding" as const,
|
||||
onboardingState,
|
||||
shouldMarkAsDefault:
|
||||
(onboardingState?.data.llmProviders ?? []).length === 0,
|
||||
onboardingActions,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess: () => {
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
},
|
||||
};
|
||||
|
||||
// Handle custom provider
|
||||
@@ -91,38 +111,36 @@ export function getOnboardingForm({
|
||||
return <CustomModal {...sharedProps} />;
|
||||
}
|
||||
|
||||
const providerProps = {
|
||||
...sharedProps,
|
||||
llmDescriptor,
|
||||
};
|
||||
|
||||
switch (llmDescriptor.name) {
|
||||
case LLMProviderName.OPENAI:
|
||||
return <OpenAIModal {...providerProps} />;
|
||||
return <OpenAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return <AnthropicModal {...providerProps} />;
|
||||
return <AnthropicModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...providerProps} />;
|
||||
return <OllamaModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureModal {...providerProps} />;
|
||||
return <AzureModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...providerProps} />;
|
||||
return <BedrockModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...providerProps} />;
|
||||
return <VertexAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...providerProps} />;
|
||||
return <OpenRouterModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioForm {...providerProps} />;
|
||||
return <LMStudioModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...providerProps} />;
|
||||
return <LiteLLMProxyModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...sharedProps} />;
|
||||
|
||||
default:
|
||||
return <CustomModal {...sharedProps} />;
|
||||
|
||||
@@ -4,6 +4,14 @@ import { expectScreenshot } from "@tests/e2e/utils/visualRegression";
|
||||
|
||||
test.use({ storageState: "admin_auth.json" });
|
||||
|
||||
/** Maps each settings slug to the header title shown on that page. */
|
||||
const SLUG_TO_HEADER: Record<string, string> = {
|
||||
general: "Profile",
|
||||
"chat-preferences": "Chats",
|
||||
"accounts-access": "Accounts",
|
||||
connectors: "Connectors",
|
||||
};
|
||||
|
||||
for (const theme of THEMES) {
|
||||
test.describe(`Settings pages (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
@@ -11,21 +19,33 @@ for (const theme of THEMES) {
|
||||
});
|
||||
|
||||
test("should screenshot each settings tab", async ({ page }) => {
|
||||
await page.goto("/app/settings");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await page.goto("/app/settings/general");
|
||||
await page
|
||||
.getByTestId("settings-left-tab-navigation")
|
||||
.waitFor({ state: "visible" });
|
||||
|
||||
const nav = page.getByTestId("settings-left-tab-navigation");
|
||||
const tabs = nav.locator("a");
|
||||
await expect(tabs.first()).toBeVisible({ timeout: 10_000 });
|
||||
const count = await tabs.count();
|
||||
|
||||
expect(count).toBeGreaterThan(0);
|
||||
for (let i = 0; i < count; i++) {
|
||||
const tab = tabs.nth(i);
|
||||
const href = await tab.getAttribute("href");
|
||||
const slug = href ? href.replace("/app/settings/", "") : `tab-${i}`;
|
||||
|
||||
await tab.click();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const expectedHeader = SLUG_TO_HEADER[slug];
|
||||
if (expectedHeader) {
|
||||
await expect(
|
||||
page
|
||||
.locator(".opal-content-md-header")
|
||||
.filter({ hasText: expectedHeader })
|
||||
).toBeVisible({ timeout: 10_000 });
|
||||
} else {
|
||||
await page.waitForLoadState("networkidle");
|
||||
}
|
||||
|
||||
await expectScreenshot(page, {
|
||||
name: `settings-${theme}-${slug}`,
|
||||
|
||||
Reference in New Issue
Block a user