Compare commits

...

8 Commits

47 changed files with 3036 additions and 2747 deletions

View File

@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
from __future__ import annotations
import hashlib
import struct
from uuid import UUID
from fastapi import APIRouter
@@ -22,6 +24,7 @@ from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -59,9 +62,25 @@ from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
# Namespace prefix for the seat-allocation advisory lock. Hashed together
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
# never block each other) and cannot collide with unrelated advisory locks.
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
return struct.unpack("q", digest[:8])[0]
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -200,12 +219,37 @@ def _apply_exclusions(
def _check_seat_availability(dal: ScimDAL) -> str | None:
    """Return an error message if seat limit is reached, else None.

    Acquires a transaction-scoped advisory lock so that concurrent
    SCIM requests are serialized. IdPs like Okta send provisioning
    requests in parallel batches — without serialization the check is
    vulnerable to a TOCTOU race where N concurrent requests each see
    "seats available", all insert, and the tenant ends up over its
    seat limit.

    The lock is held until the caller's next COMMIT or ROLLBACK, which
    means the seat count cannot change between the check here and the
    subsequent INSERT/UPDATE. Each call site in this module follows
    the pattern: _check_seat_availability → write → dal.commit()
    (which releases the lock for the next waiting request).
    """
    # Seat enforcement is an EE-only feature; in non-EE builds the fetch
    # returns None and the check is a no-op (no lock taken either).
    check_fn = fetch_ee_implementation_or_noop(
        "onyx.db.license", "check_seat_availability", None
    )
    if check_fn is None:
        return None

    # Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
    # The lock id is derived from the tenant so unrelated tenants never block
    # each other, and from a namespace string so it cannot collide with
    # unrelated advisory locks elsewhere in the codebase.
    lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
    dal.session.execute(
        text("SELECT pg_advisory_xact_lock(:lock_id)"),
        {"lock_id": lock_id},
    )

    result = check_fn(dal.session, seats_needed=1)
    if not result.available:
        return result.error_message or "Seat limit reached"
    # Seats available — fall through to the implicit None return.

View File

@@ -816,6 +816,29 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Upper bound on embedded images extracted from a single file. PDFs (and
# other formats) carrying thousands of embedded images can OOM the
# user-file-processing worker: every image is decoded with PIL and then
# sent to the vision LLM. Enforced at upload time (rejects the file) and
# again during extraction as defense-in-depth (caps the number of images
# materialized).
#
# Negative env values are clamped to 0 — a negative cap would turn upload
# validation into always-fail and extraction into always-stop, which is
# never desired. 0 is valid and disables image extraction entirely (an
# aggressive but legitimate setting).
MAX_EMBEDDED_IMAGES_PER_FILE = max(
    int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500), 0
)

# Upper bound on embedded images summed across all files in one upload
# batch. Guards the scenario where every file individually fits under
# MAX_EMBEDDED_IMAGES_PER_FILE but the batch aggregates enough work
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
# worker under concurrency or run up surprise latency/cost. Clamped to
# >= 0 like the per-file cap.
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
    int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000), 0
)
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime
from typing import TypedDict
@@ -6,6 +7,14 @@ from pydantic import BaseModel
from onyx.onyxbot.slack.models import ChannelType
@dataclass(frozen=True)
class DirectThreadFetch:
    """Request to fetch a Slack thread directly by channel and timestamp.

    Frozen (immutable) so instances are hashable and safe to use as
    dict/set members.
    """

    # Slack channel ID parsed from a message URL (e.g. "C097NBWMY8Y").
    channel_id: str
    # Thread timestamp in Slack "ts" format (e.g. "1775491616.524769").
    thread_ts: str
class ChannelMetadata(TypedDict):
"""Type definition for cached channel metadata."""

View File

@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
logger = setup_logger()
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
filtered_channels: list[str] # Channels filtered out during this query
def _fetch_thread_from_url(
    thread_fetch: DirectThreadFetch,
    access_token: str,
    channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
    """Fetch a thread directly from a Slack URL via conversations.replies.

    Args:
        thread_fetch: Channel ID + thread timestamp parsed from a Slack
            message URL found in the user's query.
        access_token: Slack token used to construct the WebClient.
        channel_metadata_dict: Optional cached channel metadata; used to
            resolve the channel name without an extra API call.

    Returns:
        A SlackQueryResult containing a single high-priority SlackMessage
        with the full thread text, or an empty result if the thread could
        not be fetched.
    """
    channel_id = thread_fetch.channel_id
    thread_ts = thread_fetch.thread_ts

    slack_client = WebClient(token=access_token)
    try:
        response = slack_client.conversations_replies(
            channel=channel_id,
            ts=thread_ts,
        )
        response.validate()
        messages: list[dict[str, Any]] = response.get("messages", [])
    except SlackApiError as e:
        # Fetch failure is non-fatal: return an empty result so the other
        # search tasks in the batch still produce output.
        logger.warning(
            f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
        )
        return SlackQueryResult(messages=[], filtered_channels=[])

    if not messages:
        logger.warning(
            f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
        )
        return SlackQueryResult(messages=[], filtered_channels=[])

    # Build thread text from all messages
    thread_text = _build_thread_text(messages, access_token, None, slack_client)

    # Get channel name from metadata cache or API
    channel_name = "unknown"
    if channel_metadata_dict and channel_id in channel_metadata_dict:
        channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
    else:
        try:
            ch_response = slack_client.conversations_info(channel=channel_id)
            ch_response.validate()
            channel_info: dict[str, Any] = ch_response.get("channel", {})
            channel_name = channel_info.get("name", "unknown")
        except SlackApiError:
            # Best-effort: the channel name is cosmetic, so fall back to
            # "unknown" rather than failing the whole fetch.
            pass

    # Build the SlackMessage
    parent_msg = messages[0]
    message_ts = parent_msg.get("ts", thread_ts)
    username = parent_msg.get("user", "unknown_user")
    parent_text = parent_msg.get("text", "")
    # Single-line snippet of the parent message for the semantic identifier.
    snippet = (
        parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
    ).replace("\n", " ")

    # NOTE(review): fromtimestamp()/now() here are timezone-naive (local
    # time). The recency math is self-consistent, but confirm it matches how
    # regular search results compute doc_time elsewhere in this module.
    doc_time = datetime.fromtimestamp(float(message_ts))
    decay_factor = DOC_TIME_DECAY
    doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
    # Floor of 0.75 keeps very old threads from being decayed into oblivion.
    recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)

    # Canonical Slack permalink format: ts with the dot removed, "p" prefix.
    permalink = (
        f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
    )

    slack_message = SlackMessage(
        document_id=f"{channel_id}_{message_ts}",
        channel_id=channel_id,
        message_id=message_ts,
        thread_id=None,  # Prevent double-enrichment in thread context fetch
        link=permalink,
        metadata={
            "channel": channel_name,
            "time": doc_time.isoformat(),
        },
        timestamp=doc_time,
        recency_bias=recency_bias,
        semantic_identifier=f"{username} in #{channel_name}: {snippet}",
        text=thread_text,
        highlighted_texts=set(),
        slack_score=100000.0,  # High priority — user explicitly asked for this thread
    )

    logger.info(
        f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
    )
    return SlackQueryResult(messages=[slack_message], filtered_channels=[])
def query_slack(
query_string: str,
access_token: str,
@@ -432,7 +519,6 @@ def query_slack(
available_channels: list[str] | None = None,
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
) -> SlackQueryResult:
# Check if query has channel override (user specified channels in query)
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
@@ -662,7 +748,6 @@ def _fetch_thread_context(
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# If not a thread, return original text as success
if thread_id is None:
@@ -695,62 +780,37 @@ def _fetch_thread_context(
if len(messages) <= 1:
return ThreadContextResult.success(message.text)
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
# Build thread text from thread starter + all replies
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
return ThreadContextResult.success(thread_text)
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build the thread text from messages."""
"""Build thread text including all replies.
Includes the thread parent message followed by all replies in order.
"""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# All messages after index 0 are replies
replies = messages[1:]
if not replies:
return thread_text
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
if start_idx > 1:
thread_text += "\n..."
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
for msg in replies:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
@@ -976,7 +1036,16 @@ def slack_retrieval(
# Query slack with entity filtering
llm = get_default_llm()
query_strings = build_slack_queries(query, llm, entities, available_channels)
query_items = build_slack_queries(query, llm, entities, available_channels)
# Partition into direct thread fetches and search query strings
direct_fetches: list[DirectThreadFetch] = []
query_strings: list[str] = []
for item in query_items:
if isinstance(item, DirectThreadFetch):
direct_fetches.append(item)
else:
query_strings.append(item)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -993,8 +1062,16 @@ def slack_retrieval(
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
)
# Build search tasks
search_tasks = [
# Build search tasks — direct thread fetches + keyword searches
search_tasks: list[tuple] = [
(
_fetch_thread_from_url,
(fetch, access_token, channel_metadata_dict),
)
for fetch in direct_fetches
]
search_tasks.extend(
(
query_slack,
(
@@ -1010,7 +1087,7 @@ def slack_retrieval(
),
)
for query_string in query_strings
]
)
# If include_dm is True AND we're not already searching all channels,
# add additional searches without channel filters.

View File

@@ -10,6 +10,7 @@ from pydantic import ValidationError
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.context.search.federated.models import ChannelMetadata
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.models import ChunkIndexRequest
from onyx.federated_connectors.slack.models import SlackEntities
from onyx.llm.interfaces import LLM
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
return [query_text]
SLACK_URL_PATTERN = re.compile(
    # Host: any *.slack.com workspace. "[a-zA-Z0-9.-]+" (rather than the
    # stricter "[a-z0-9-]+") also matches Enterprise Grid hosts that carry
    # an extra subdomain (e.g. org.enterprise.slack.com) and hosts typed
    # with uppercase letters. "/" is excluded, so the host cannot bleed
    # past the authority into the path.
    r"https?://[a-zA-Z0-9.-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
)


def extract_slack_message_urls(
    query_text: str,
) -> list[tuple[str, str]]:
    """Extract Slack message URLs from query text.

    Parses URLs like:
        https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
    including Enterprise Grid hosts such as org.enterprise.slack.com.

    Returns list of (channel_id, thread_ts) tuples, in order of appearance
    (duplicates preserved). The 16-digit "p" timestamp is converted to
    Slack ts format: first 10 digits, a dot, then the remaining 6.
    """
    results = []
    for match in SLACK_URL_PATTERN.finditer(query_text):
        channel_id = match.group(1)
        raw_ts = match.group(2)
        # Convert p1775491616524769 -> 1775491616.524769
        thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
        results.append((channel_id, thread_ts))
    return results
def build_slack_queries(
query: ChunkIndexRequest,
llm: LLM,
entities: dict[str, Any] | None = None,
available_channels: list[str] | None = None,
) -> list[str]:
) -> list[str | DirectThreadFetch]:
"""Build Slack query strings with date filtering and query expansion."""
default_search_days = 30
if entities:
@@ -668,6 +695,15 @@ def build_slack_queries(
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
# Check for Slack message URLs — if found, add direct fetch requests
url_fetches: list[DirectThreadFetch] = []
slack_urls = extract_slack_message_urls(query.query)
for channel_id, thread_ts in slack_urls:
url_fetches.append(
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
)
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
# ALWAYS extract channel references from the query (not just for recency queries)
channel_references = extract_channel_references_from_query(query.query)
@@ -684,7 +720,9 @@ def build_slack_queries(
# If valid channels detected, use ONLY those channels with NO keywords
# Return query with ONLY time filter + channel filter (no keywords)
return [build_channel_override_query(channel_references, time_filter)]
return url_fetches + [
build_channel_override_query(channel_references, time_filter)
]
except ValueError as e:
# If validation fails, log the error and continue with normal flow
logger.warning(f"Channel reference validation failed: {e}")
@@ -702,7 +740,8 @@ def build_slack_queries(
rephrased_queries = expand_query_with_llm(query.query, llm)
# Build final query strings with time filters
return [
search_queries = [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
return url_fetches + search_queries

View File

@@ -21,6 +21,7 @@ import chardet
import openpyxl
from PIL import Image
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -188,6 +189,56 @@ def read_text_file(
return file_content_raw, metadata
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
    """Return the number of embedded images in a PDF, short-circuiting at cap+1.

    Used to reject PDFs whose image count would OOM the user-file-processing
    worker during indexing. Once the running count exceeds ``cap`` it is
    returned immediately as a sentinel (> cap), so callers never iterate
    thousands of image objects just to report a number. Returns 0 if the PDF
    cannot be parsed.

    Owner-password-only PDFs (permission restrictions but no open password)
    decrypt with an empty string and are counted normally. Truly
    password-locked PDFs return 0 since we can't inspect them; the caller
    should run the password-protected check first.

    The stream position is always restored before returning.
    """
    from pypdf import PdfReader

    try:
        original_pos = file.tell()
    except Exception:
        original_pos = None

    try:
        if original_pos is not None:
            file.seek(0)
        reader = PdfReader(file)

        if reader.is_encrypted:
            # Attempt the empty password (owner-password-only PDFs);
            # anything else means we cannot inspect the document.
            try:
                decrypt_result = reader.decrypt("")
            except Exception:
                return 0
            if decrypt_result == 0:
                return 0

        total = 0
        for page in reader.pages:
            for _image in page.images:
                total += 1
                if total > cap:
                    # Sentinel: strictly greater than cap, stop counting.
                    return total
        return total
    except Exception:
        logger.warning("Failed to count embedded images in PDF", exc_info=True)
        return 0
    finally:
        if original_pos is not None:
            try:
                file.seek(original_pos)
            except Exception:
                # Restoring the position is best-effort only.
                pass
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
"""
Extract text from a PDF. For embedded images, a more complex approach is needed.
@@ -251,8 +302,27 @@ def read_pdf_file(
)
if extract_images:
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
images_processed = 0
cap_reached = False
for page_num, page in enumerate(pdf_reader.pages):
if cap_reached:
break
for image_file_object in page.images:
if images_processed >= image_cap:
# Defense-in-depth backstop. Upload-time validation
# should have rejected files exceeding the cap, but
# we also break here so a single oversized file can
# never pin a worker.
logger.warning(
"PDF embedded image cap reached (%d). "
"Skipping remaining images on page %d and beyond.",
image_cap,
page_num + 1,
)
cap_reached = True
break
image = Image.open(io.BytesIO(image_file_object.data))
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format=image.format)
@@ -265,6 +335,7 @@ def read_pdf_file(
image_callback(img_bytes, image_name)
else:
extracted_images.append((img_bytes, image_name))
images_processed += 1
return text, metadata, extracted_images

View File

@@ -26,6 +26,7 @@ class LlmProviderNames(str, Enum):
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
OPENAI_COMPATIBLE = "openai_compatible"
def __str__(self) -> str:
"""Needed so things like:
@@ -46,6 +47,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
]
@@ -64,6 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -116,6 +119,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
# Model family name mappings for display name generation

View File

@@ -175,6 +175,28 @@ def _strip_tool_content_from_messages(
return result
def _fix_tool_user_message_ordering(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Insert a synthetic assistant message between tool and user messages.
Some models (e.g. Mistral on Azure) require strict message ordering where
a user message cannot immediately follow a tool message. This function
inserts a minimal assistant message to bridge the gap.
"""
if len(messages) < 2:
return messages
result: list[dict[str, Any]] = [messages[0]]
for msg in messages[1:]:
prev_role = result[-1].get("role")
curr_role = msg.get("role")
if prev_role == "tool" and curr_role == "user":
result.append({"role": "assistant", "content": "Noted. Continuing."})
result.append(msg)
return result
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
"""Check if any messages contain tool-related content blocks."""
for msg in messages:
@@ -305,12 +327,19 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
# model names directly to the endpoint. We route through LiteLLM's
# openai provider with the server's base URL, and ensure /v1 is appended.
if model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
):
self._custom_llm_provider = "openai"
# LiteLLM's OpenAI client requires an api_key to be set.
# Many OpenAI-compatible servers don't need auth, so supply a
# placeholder to prevent LiteLLM from raising AuthenticationError.
if not self._api_key:
model_kwargs.setdefault("api_key", "not-needed")
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
@@ -427,17 +456,20 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
is_openai_compatible_proxy = self._model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
)
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
if is_openai_compatible_proxy:
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
# servers) expect model names sent directly to their endpoint.
# We use custom_llm_provider="openai" so LiteLLM doesn't try
# to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
@@ -528,7 +560,10 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
if (
not (is_claude_model or is_ollama or is_mistral)
or is_openai_compatible_proxy
):
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is
@@ -576,6 +611,18 @@ class LitellmLLM(LLM):
):
messages = _strip_tool_content_from_messages(messages)
# Some models (e.g. Mistral) reject a user message
# immediately after a tool message. Insert a synthetic
# assistant bridge message to satisfy the ordering
# constraint. Check both the provider and the deployment/
# model name to catch Mistral hosted on Azure.
model_or_deployment = (
self._deployment_name or self._model_version or ""
).lower()
is_mistral_model = is_mistral or "mistral" in model_or_deployment
if is_mistral_model:
messages = _fix_tool_user_message_ordering(messages)
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
# reject requests where tool_choice is explicitly null.
if tools and tool_choice is not None:

View File

@@ -15,6 +15,8 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -19,6 +19,7 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
@@ -51,6 +52,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
}
@@ -336,6 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

View File

@@ -40,6 +40,8 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as celery_app
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -50,6 +52,9 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentMetadata
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
@@ -127,6 +132,49 @@ class DeleteFileResponse(BaseModel):
# =============================================================================
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
"""True if either the filename or the content-type indicates a PDF.
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
``Content-Type: application/octet-stream``), so we also fall back to
extension-based detection via ``mimetypes.guess_type`` on the filename.
"""
if content_type == "application/pdf":
return True
guessed, _ = mimetypes.guess_type(filename)
return guessed == "application/pdf"
def _check_pdf_image_caps(
    filename: str, content: bytes, content_type: str | None, batch_total: int
) -> int:
    """Enforce per-file and per-batch embedded-image caps for PDFs.

    Args:
        filename: Uploaded file name (used for PDF detection and messages).
        content: Raw file bytes.
        content_type: Client-supplied MIME type (may be spoofed).
        batch_total: Running image count across files already accepted in
            this upload batch.

    Returns the number of embedded images in this file (0 for non-PDFs) so
    callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
    if either cap is exceeded.
    """
    if not _looks_like_pdf(filename, content_type):
        return 0

    file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
    batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
    # Short-circuit at the larger cap so we get a useful count for both checks.
    count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
    if count > file_cap:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            # Bug fix: the message previously hardcoded '(unknown)' instead
            # of interpolating the filename the caller passed in.
            f"PDF '{filename}' contains too many embedded images "
            f"(more than {file_cap}). Try splitting the document into smaller files.",
        )
    if batch_total + count > batch_cap:
        raise OnyxError(
            OnyxErrorCode.INVALID_INPUT,
            f"Upload would exceed the {batch_cap}-image limit across all "
            f"files in this batch. Try uploading fewer image-heavy files at once.",
        )
    return count
def _sanitize_path(path: str) -> str:
"""Sanitize a file path, removing traversal attempts and normalizing.
@@ -355,6 +403,7 @@ async def upload_files(
uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
now = datetime.now(timezone.utc)
# Sanitize the base path
@@ -374,6 +423,14 @@ async def upload_files(
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
)
# Reject PDFs with an unreasonable per-file or per-batch image count
batch_image_total += _check_pdf_image_caps(
filename=file.filename or "unnamed",
content=content,
content_type=file.content_type,
batch_total=batch_image_total,
)
# Validate cumulative storage (existing + this upload batch)
total_size += file_size
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
@@ -472,6 +529,7 @@ async def upload_zip(
uploaded_entries: list[LibraryEntryResponse] = []
total_size = 0
batch_image_total = 0
# Extract zip contents into a subfolder named after the zip file
zip_name = api_sanitize_filename(file.filename or "upload")
@@ -510,6 +568,36 @@ async def upload_zip(
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
continue
# Skip PDFs that would trip the per-file or per-batch image
# cap (would OOM the user-file-processing worker). Matches
# /upload behavior but uses skip-and-warn to stay consistent
# with the zip path's handling of oversized files.
zip_file_name = zip_info.filename.split("/")[-1]
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
if zip_content_type == "application/pdf":
image_count = count_pdf_embedded_images(
BytesIO(file_content),
max(
MAX_EMBEDDED_IMAGES_PER_FILE,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
),
)
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
logger.warning(
"Skipping '%s' - exceeds %d per-file embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_FILE,
)
continue
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
logger.warning(
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
zip_info.filename,
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
)
continue
batch_image_total += image_count
total_size += file_size
# Validate cumulative storage

View File

@@ -10,9 +10,12 @@ from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
@@ -192,6 +195,11 @@ def categorize_uploaded_files(
except RuntimeError as e:
logger.warning(f"Failed to get current tenant ID: {str(e)}")
# Running total of embedded images across PDFs in this batch. Once the
# aggregate cap is reached, subsequent PDFs in the same upload are
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
batch_image_total = 0
for upload in files:
try:
filename = get_safe_filename(upload)
@@ -253,6 +261,47 @@ def categorize_uploaded_files(
)
continue
# Reject PDFs with an unreasonable number of embedded images
# (either per-file or accumulated across this upload batch).
# A PDF with thousands of embedded images can OOM the
# user-file-processing celery worker because every image is
# decoded with PIL and then sent to the vision LLM.
if extension == ".pdf":
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
# Use the larger of the two caps as the short-circuit
# threshold so we get a useful count for both checks.
# count_pdf_embedded_images restores the stream position.
count = count_pdf_embedded_images(
upload.file, max(file_cap, batch_cap)
)
if count > file_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"PDF contains too many embedded images "
f"(more than {file_cap}). Try splitting "
f"the document into smaller files."
),
)
)
continue
if batch_image_total + count > batch_cap:
results.rejected.append(
RejectedFile(
filename=filename,
reason=(
f"Upload would exceed the "
f"{batch_cap}-image limit across all "
f"files in this batch. Try uploading "
f"fewer image-heavy files at once."
),
)
)
continue
batch_image_total += count
text_content = extract_file_text(
file=upload.file,
file_name=filename,

View File

@@ -74,6 +74,8 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
@@ -1575,3 +1577,95 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
source_name="Bifrost",
api_key=api_key,
)
@admin_router.post("/openai-compatible/available-models")
def get_openai_compatible_server_available_models(
    request: OpenAICompatibleModelsRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[OpenAICompatibleFinalModelResponse]:
    """Fetch available models from a generic OpenAI-compatible /v1/models endpoint.

    Queries the remote server, drops embedding models, infers vision and
    reasoning capability from the model id/name, and — when
    ``request.provider_name`` is set — syncs the resulting list into the DB
    for that provider.

    Raises:
        OnyxError: ``VALIDATION_ERROR`` when the endpoint returns no models,
            or when no returned entry is a usable chat model.
    """
    response_json = _get_openai_compatible_server_response(
        api_base=request.api_base, api_key=request.api_key
    )
    models = response_json.get("data", [])
    if not isinstance(models, list) or len(models) == 0:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No models found from your OpenAI-compatible endpoint",
        )
    results: list[OpenAICompatibleFinalModelResponse] = []
    for model in models:
        try:
            model_id = model.get("id", "")
            if not model_id:
                continue
            # Some servers return an explicit "name": null (or "") — using
            # `or` instead of a .get() default guarantees display_name is
            # always a non-empty string.
            model_name = model.get("name") or model_id
            # Skip embedding models — only chat-capable models are usable here.
            if is_embedding_model(model_id):
                continue
            results.append(
                OpenAICompatibleFinalModelResponse(
                    name=model_id,
                    display_name=model_name,
                    max_input_tokens=model.get("context_length"),
                    supports_image_input=infer_vision_support(model_id),
                    supports_reasoning=is_reasoning_model(model_id, model_name),
                )
            )
        except Exception as e:
            # One malformed entry must not fail the whole fetch; log and move on.
            logger.warning(
                "Failed to parse OpenAI-compatible model entry",
                extra={"error": str(e), "item": str(model)[:1000]},
            )
    if not results:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No compatible models found from OpenAI-compatible endpoint",
        )
    sorted_results = sorted(results, key=lambda m: m.name.lower())
    # Sync new models to DB if provider_name is specified
    if request.provider_name:
        _sync_fetched_models(
            db_session=db_session,
            provider_name=request.provider_name,
            models=[
                SyncModelEntry(
                    name=r.name,
                    display_name=r.display_name,
                    max_input_tokens=r.max_input_tokens,
                    supports_image_input=r.supports_image_input,
                )
                for r in sorted_results
            ],
            source_label="OpenAI Compatible",
        )
    return sorted_results
def _get_openai_compatible_server_response(
    api_base: str, api_key: str | None = None
) -> dict:
    """Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
    base = api_base.strip().rstrip("/")
    # Normalize to exactly one trailing "/v1/models" path segment, whether or
    # not the configured base already ends in "/v1".
    url = f"{base}/models" if base.endswith("/v1") else f"{base}/v1/models"
    return _get_openai_compatible_models_response(
        url=url,
        source_name="OpenAI Compatible",
        api_key=api_key,
    )

View File

@@ -464,3 +464,18 @@ class BifrostFinalModelResponse(BaseModel):
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool
# OpenAI Compatible dynamic models fetch
class OpenAICompatibleModelsRequest(BaseModel):
    """Request body for listing models from a generic OpenAI-compatible server."""

    api_base: str  # Base URL of the server; "/v1/models" is appended when fetching
    api_key: str | None = None  # Optional bearer credential for the remote server
    provider_name: str | None = None  # Optional: to save models to existing provider
class OpenAICompatibleFinalModelResponse(BaseModel):
    """A single usable chat model returned by the available-models endpoint."""

    name: str  # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
    display_name: str  # Human-readable name from API
    max_input_tokens: int | None  # Context window if reported, else None
    supports_image_input: bool
    supports_reasoning: bool

View File

@@ -26,6 +26,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
)

View File

@@ -0,0 +1,67 @@
"""Tests for _build_thread_text function."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.slack_search import _build_thread_text
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
return {"user": user, "text": text, "ts": ts}
class TestBuildThreadText:
    """Verify _build_thread_text includes full thread replies up to cap."""

    @patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
    def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
        """All replies within cap are included in output."""
        # Empty profile map: user IDs are left unresolved, which is fine here
        # since only the message text is asserted on.
        mock_profiles.return_value = {}
        messages = [
            _make_msg("U1", "parent msg", "1000.0"),
            _make_msg("U2", "reply 1", "1001.0"),
            _make_msg("U3", "reply 2", "1002.0"),
            _make_msg("U4", "reply 3", "1003.0"),
        ]
        result = _build_thread_text(messages, "token", "T123", MagicMock())
        assert "parent msg" in result
        assert "reply 1" in result
        assert "reply 2" in result
        assert "reply 3" in result
        # NOTE(review): "..." is presumably the truncation marker emitted when
        # the reply cap is hit — confirm against _build_thread_text.
        assert "..." not in result

    @patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
    def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
        """Single message (no replies) returns just the parent text."""
        mock_profiles.return_value = {}
        messages = [_make_msg("U1", "just a message", "1000.0")]
        result = _build_thread_text(messages, "token", "T123", MagicMock())
        assert "just a message" in result
        # The "Replies:" header must only appear for actual threads.
        assert "Replies:" not in result

    @patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
    def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
        """Thread parent message is always the first line of output."""
        mock_profiles.return_value = {}
        messages = [
            _make_msg("U1", "I am the parent", "1000.0"),
            _make_msg("U2", "I am a reply", "1001.0"),
        ]
        result = _build_thread_text(messages, "token", "T123", MagicMock())
        # Compare character offsets: the parent text must precede the reply.
        parent_pos = result.index("I am the parent")
        reply_pos = result.index("I am a reply")
        assert parent_pos < reply_pos

    @patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
    def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
        """User IDs in thread text are replaced with display names."""
        mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
        messages = [
            _make_msg("U1", "hello", "1000.0"),
            _make_msg("U2", "world", "1001.0"),
        ]
        result = _build_thread_text(messages, "token", "T123", MagicMock())
        assert "Alice" in result
        assert "Bob" in result
        # Raw <@Uxx> mention syntax must be fully substituted away.
        assert "<@U1>" not in result
        assert "<@U2>" not in result

View File

@@ -0,0 +1,108 @@
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.context.search.federated.models import DirectThreadFetch
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
class TestExtractSlackMessageUrls:
    """Verify URL parsing extracts channel_id and timestamp correctly."""

    def test_standard_url(self) -> None:
        matches = extract_slack_message_urls(
            "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
        )
        assert matches == [("C097NBWMY8Y", "1775491616.524769")]

    def test_multiple_urls(self) -> None:
        prompt = (
            "compare https://co.slack.com/archives/C111/p1234567890123456 "
            "and https://co.slack.com/archives/C222/p9876543210987654"
        )
        matches = extract_slack_message_urls(prompt)
        assert matches == [
            ("C111", "1234567890.123456"),
            ("C222", "9876543210.987654"),
        ]

    def test_no_urls(self) -> None:
        assert extract_slack_message_urls("what happened in #general last week?") == []

    def test_non_slack_url_ignored(self) -> None:
        prompt = "check https://google.com/archives/C111/p1234567890123456"
        assert extract_slack_message_urls(prompt) == []

    def test_timestamp_conversion(self) -> None:
        """p prefix removed, dot inserted after 10th digit."""
        matches = extract_slack_message_urls(
            "https://x.slack.com/archives/CABC123/p1775491616524769"
        )
        channel, ts = matches[0]
        assert channel == "CABC123"
        assert ts == "1775491616.524769"
        assert not ts.startswith("p")
        assert "." in ts
class TestFetchThreadFromUrl:
    """Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""

    @patch("onyx.context.search.federated.slack_search._build_thread_text")
    @patch("onyx.context.search.federated.slack_search.WebClient")
    def test_successful_fetch(
        self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
    ) -> None:
        """A resolvable URL yields exactly one fully-populated SlackMessage."""
        mock_client = MagicMock()
        mock_webclient_cls.return_value = mock_client
        # Mock conversations_replies — parent message plus two replies.
        mock_response = MagicMock()
        mock_response.get.return_value = [
            {"user": "U1", "text": "parent", "ts": "1775491616.524769"},
            {"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
            {"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
        ]
        mock_client.conversations_replies.return_value = mock_response
        # Mock channel info
        mock_ch_response = MagicMock()
        mock_ch_response.get.return_value = {"name": "general"}
        mock_client.conversations_info.return_value = mock_ch_response
        mock_build_thread.return_value = (
            "U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
        )
        fetch = DirectThreadFetch(
            channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
        )
        result = _fetch_thread_from_url(fetch, "xoxp-token")
        assert len(result.messages) == 1
        msg = result.messages[0]
        assert msg.channel_id == "C097NBWMY8Y"
        assert msg.thread_id is None  # Prevents double-enrichment
        # NOTE(review): 100000.0 is presumably a sentinel score pinning
        # URL-fetched threads above ordinary search hits — confirm.
        assert msg.slack_score == 100000.0
        assert "parent" in msg.text
        # The thread must be fetched via conversations.replies with the exact
        # channel/ts parsed from the URL.
        mock_client.conversations_replies.assert_called_once_with(
            channel="C097NBWMY8Y", ts="1775491616.524769"
        )

    @patch("onyx.context.search.federated.slack_search.WebClient")
    def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
        """A Slack API failure degrades to an empty result, not an exception."""
        from slack_sdk.errors import SlackApiError

        mock_client = MagicMock()
        mock_webclient_cls.return_value = mock_client
        mock_client.conversations_replies.side_effect = SlackApiError(
            message="channel_not_found",
            response=MagicMock(status_code=404),
        )
        fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
        result = _fetch_thread_from_url(fetch, "xoxp-token")
        assert len(result.messages) == 0
View File

@@ -12,6 +12,10 @@ dependency on pypdf internals (pypdf.generic).
from io import BytesIO
from pathlib import Path
import pytest
from onyx.file_processing import extract_file_text
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
from onyx.file_processing.extract_file_text import pdf_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.password_validation import is_pdf_protected
@@ -96,6 +100,80 @@ class TestReadPdfFile:
# Returned list is empty when callback is used
assert images == []
def test_image_cap_skips_images_above_limit(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""When the embedded-image cap is exceeded, remaining images are skipped.
The cap protects the user-file-processing worker from OOMing on PDFs
with thousands of embedded images. Setting the cap to 0 should yield
zero extracted images even though the fixture has one.
"""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert images == []
def test_image_cap_at_limit_extracts_up_to_cap(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A cap >= image count behaves identically to the uncapped path."""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
assert len(images) == 1
def test_image_cap_with_callback_stops_streaming_at_limit(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""The cap also short-circuits the streaming callback path."""
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
collected: list[tuple[bytes, str]] = []
def callback(data: bytes, name: str) -> None:
collected.append((data, name))
read_pdf_file(
_load("with_image.pdf"), extract_images=True, image_callback=callback
)
assert collected == []
# ── count_pdf_embedded_images ────────────────────────────────────────────
class TestCountPdfEmbeddedImages:
    """Behavioral contract of count_pdf_embedded_images across fixture PDFs."""

    def test_returns_count_for_normal_pdf(self) -> None:
        # with_image.pdf contains exactly one embedded image.
        assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1

    def test_short_circuits_above_cap(self) -> None:
        # with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
        # function returns on first increment as the over-cap sentinel.
        assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1

    def test_returns_zero_for_pdf_without_images(self) -> None:
        assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0

    def test_returns_zero_for_invalid_pdf(self) -> None:
        # Unparseable bytes must not raise; the helper degrades to 0.
        assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0

    def test_returns_zero_for_password_locked_pdf(self) -> None:
        # encrypted.pdf has an open password; we can't inspect without it, so
        # the helper returns 0 — callers rely on the password-protected check
        # that runs earlier in the upload pipeline.
        assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0

    def test_inspects_owner_password_only_pdf(self) -> None:
        # owner_protected.pdf is encrypted but has no open password. It should
        # decrypt with an empty string and count images normally. The fixture
        # has zero images, so 0 is a real count (not the "bail on encrypted"
        # path).
        assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0

    def test_preserves_file_position(self) -> None:
        # Counting must restore the stream position so downstream extraction
        # can re-read the same file object.
        pdf = _load("with_image.pdf")
        pdf.seek(42)
        count_pdf_embedded_images(pdf, cap=10)
        assert pdf.tell() == 42
# ── pdf_to_text ──────────────────────────────────────────────────────────

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
@@ -9,7 +10,9 @@ from uuid import uuid4
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.server.scim.api import _check_seat_availability
from ee.onyx.server.scim.api import _scim_name_to_str
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_user
@@ -741,3 +744,80 @@ class TestEmailCasePreservation:
resource = parse_scim_user(result)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
class TestSeatLock:
    """Tests for the advisory lock in _check_seat_availability."""

    @patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
    def test_acquires_advisory_lock_before_checking(
        self,
        _mock_tenant: MagicMock,
        mock_dal: MagicMock,
    ) -> None:
        """The advisory lock must be acquired before the seat check runs."""
        # Record event ordering via side effects on both the session.execute
        # (lock acquisition) and the EE seat-check callable.
        call_order: list[str] = []

        def track_execute(stmt: Any, _params: Any = None) -> None:
            # Only the pg_advisory_xact_lock statement counts as the lock event.
            if "pg_advisory_xact_lock" in str(stmt):
                call_order.append("lock")

        mock_dal.session.execute.side_effect = track_execute
        with patch(
            "ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
        ) as mock_fetch:
            mock_result = MagicMock()
            mock_result.available = True
            mock_fn = MagicMock(return_value=mock_result)
            mock_fetch.return_value = mock_fn

            def track_check(*_args: Any, **_kwargs: Any) -> Any:
                call_order.append("check")
                return mock_result

            mock_fn.side_effect = track_check
            _check_seat_availability(mock_dal)
        # Lock strictly before check — the serialization guarantee under test.
        assert call_order == ["lock", "check"]

    @patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
    def test_lock_uses_tenant_scoped_key(
        self,
        _mock_tenant: MagicMock,
        mock_dal: MagicMock,
    ) -> None:
        """The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
        mock_result = MagicMock()
        mock_result.available = True
        mock_check = MagicMock(return_value=mock_result)
        with patch(
            "ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
            return_value=mock_check,
        ):
            _check_seat_availability(mock_dal)
        mock_dal.session.execute.assert_called_once()
        # Second positional arg of session.execute is the bound-params dict.
        params = mock_dal.session.execute.call_args[0][1]
        assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")

    def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
        """Lock id must be deterministic and differ across tenants."""
        assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
        assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")

    def test_no_lock_when_ee_absent(
        self,
        mock_dal: MagicMock,
    ) -> None:
        """No advisory lock should be acquired when the EE check is absent."""
        with patch(
            "ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
            return_value=None,
        ):
            result = _check_seat_availability(mock_dal)
        assert result is None
        mock_dal.session.execute.assert_not_called()

View File

@@ -1,7 +1 @@
"use client";
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
export default function Page() {
return <LLMConfigurationPage />;
}
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";

View File

@@ -32,8 +32,10 @@ import {
OpenRouterFetchParams,
LiteLLMProxyFetchParams,
BifrostFetchParams,
OpenAICompatibleFetchParams,
OpenAICompatibleModelResponse,
} from "@/interfaces/llm";
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -44,6 +46,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
"lm_studio",
"litellm_proxy",
"bifrost",
"openai_compatible",
"vertex_ai",
]);
@@ -82,6 +85,7 @@ export const getProviderIcon = (
openrouter: SvgOpenrouter,
litellm_proxy: LiteLLMIcon,
bifrost: SvgBifrost,
openai_compatible: SvgPlug,
vertex_ai: GeminiIcon,
};
@@ -411,6 +415,64 @@ export const fetchBifrostModels = async (
}
};
/**
 * Fetches models from a generic OpenAI-compatible server.
 * Uses snake_case params to match API structure.
 */
export const fetchOpenAICompatibleModels = async (
  params: OpenAICompatibleFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
  if (!params.api_base) {
    return { models: [], error: "API Base is required" };
  }

  try {
    const response = await fetch(
      "/api/admin/llm/openai-compatible/available-models",
      {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          api_base: params.api_base,
          api_key: params.api_key,
          provider_name: params.provider_name,
        }),
        signal: params.signal,
      }
    );

    if (!response.ok) {
      let message = "Failed to fetch models";
      try {
        const body = await response.json();
        message = body.detail || body.message || message;
      } catch {
        // ignore JSON parsing errors
      }
      return { models: [], error: message };
    }

    const payload: OpenAICompatibleModelResponse[] = await response.json();
    const models: ModelConfiguration[] = payload.map((entry) => ({
      name: entry.name,
      display_name: entry.display_name,
      is_visible: true,
      max_input_tokens: entry.max_input_tokens,
      supports_image_input: entry.supports_image_input,
      supports_reasoning: entry.supports_reasoning,
    }));
    return { models };
  } catch (err) {
    return {
      models: [],
      error: err instanceof Error ? err.message : "Unknown error",
    };
  }
};
/**
* Fetches LiteLLM Proxy models directly without any form state dependencies.
* Uses snake_case params to match API structure.
@@ -531,6 +593,13 @@ export const fetchModels = async (
provider_name: formValues.name,
signal,
});
case LLMProviderName.OPENAI_COMPATIBLE:
return fetchOpenAICompatibleModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
default:
return { models: [], error: `Unknown provider: ${providerName}` };
}
@@ -545,6 +614,7 @@ export function canProviderFetchModels(providerName?: string) {
case LLMProviderName.OPENROUTER:
case LLMProviderName.LITELLM_PROXY:
case LLMProviderName.BIFROST:
case LLMProviderName.OPENAI_COMPATIBLE:
return true;
default:
return false;

View File

@@ -24,7 +24,6 @@ import {
} from "@/app/craft/onboarding/constants";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
import { buildOnboardingInitialValues as buildInitialValues } from "@/sections/modals/llmConfig/utils";
import { testApiKeyHelper } from "@/sections/modals/llmConfig/svc";
import OnboardingInfoPages from "@/app/craft/onboarding/components/OnboardingInfoPages";
import OnboardingUserInfo from "@/app/craft/onboarding/components/OnboardingUserInfo";
@@ -221,10 +220,8 @@ export default function BuildOnboardingModal({
setConnectionStatus("testing");
setErrorMessage("");
const baseValues = buildInitialValues();
const providerName = `build-mode-${currentProviderConfig.providerName}`;
const payload = {
...baseValues,
name: providerName,
provider: currentProviderConfig.providerName,
api_key: apiKey,

View File

@@ -272,6 +272,22 @@ export default function UserLibraryModal({
</Disabled>
</Section>
{/* The exact cap is controlled by the backend env var
MAX_EMBEDDED_IMAGES_PER_FILE (default 500). This copy is
deliberately vague so it doesn't drift if the limit is
tuned per-deployment; the precise number is surfaced in
the rejection error the server returns. */}
<Section
flexDirection="row"
justifyContent="end"
padding={0.5}
height="fit"
>
<Text secondaryBody text03>
PDFs with many embedded images may be rejected.
</Text>
</Section>
{isLoading ? (
<Section padding={2} height="fit">
<Text secondaryBody text03>

View File

@@ -4,6 +4,7 @@ import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import {
LLMProviderDescriptor,
LLMProviderName,
LLMProviderResponse,
LLMProviderView,
WellKnownLLMProviderDescriptor,
@@ -136,14 +137,12 @@ export function useAdminLLMProviders() {
* Used inside individual provider modals to pre-populate model lists
* before the user has entered credentials.
*
* @param providerEndpoint - The provider's API endpoint name (e.g. "openai", "anthropic").
* @param providerName - The provider's API endpoint name (e.g. "openai", "anthropic").
* Pass `null` to suppress the request.
*/
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
export function useWellKnownLLMProvider(providerName: LLMProviderName) {
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
providerEndpoint
? `/api/admin/llm/built-in/options/${providerEndpoint}`
: null,
providerName ? `/api/admin/llm/built-in/options/${providerName}` : null,
errorHandlingFetcher,
{
revalidateOnFocus: false,

View File

@@ -14,6 +14,7 @@ export enum LLMProviderName {
BEDROCK = "bedrock",
LITELLM_PROXY = "litellm_proxy",
BIFROST = "bifrost",
OPENAI_COMPATIBLE = "openai_compatible",
CUSTOM = "custom",
}
@@ -123,14 +124,11 @@ export interface LLMProviderFormProps {
shouldMarkAsDefault?: boolean;
open?: boolean;
onOpenChange?: (open: boolean) => void;
/** The current default model name for this provider (from the global default). */
defaultModelName?: string;
/** Called after successful provider creation/update. */
onSuccess?: () => void | Promise<void>;
// Onboarding-specific (only when variant === "onboarding")
onboardingState?: OnboardingState;
onboardingActions?: OnboardingActions;
llmDescriptor?: WellKnownLLMProviderDescriptor;
}
// Param types for model fetching functions - use snake_case to match API structure
@@ -181,6 +179,21 @@ export interface BifrostModelResponse {
supports_reasoning: boolean;
}
export interface OpenAICompatibleFetchParams {
api_base?: string;
api_key?: string;
provider_name?: string;
signal?: AbortSignal;
}
export interface OpenAICompatibleModelResponse {
name: string;
display_name: string;
max_input_tokens: number | null;
supports_image_input: boolean;
supports_reasoning: boolean;
}
export interface VertexAIFetchParams {
model_configurations?: ModelConfiguration[];
}
@@ -199,5 +212,6 @@ export type FetchModelsParams =
| OpenRouterFetchParams
| LiteLLMProxyFetchParams
| BifrostFetchParams
| OpenAICompatibleFetchParams
| VertexAIFetchParams
| LMStudioFetchParams;

View File

@@ -1,8 +1,9 @@
"use client";
import type { RichStr } from "@opal/types";
import type { RichStr, WithoutStyles } from "@opal/types";
import { resolveStr } from "@opal/components/text/InlineMarkdown";
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
import { useField, useFormikContext } from "formik";
import { Section } from "@/layouts/general-layouts";
@@ -234,9 +235,27 @@ function ErrorTextLayout({ children, type = "error" }: ErrorTextLayoutProps) {
);
}
/**
* FieldSeparator - A horizontal rule with inline padding, used to visually separate field groups.
*/
function FieldSeparator() {
  // NOTE(review): noPadding with className="p-2" — presumably replaces the
  // Separator's default spacing with uniform padding; confirm against Separator.
  return <Separator noPadding className="p-2" />;
}
/**
* FieldPadder - Wraps a field in standard horizontal + vertical padding (`p-2 w-full`).
*/
type FieldPadderProps = WithoutStyles<React.HTMLAttributes<HTMLDivElement>>;
function FieldPadder(props: FieldPadderProps) {
  // Props are spread BEFORE className, so a caller-supplied className can
  // never override the standard "p-2 w-full" padding.
  return <div {...props} className="p-2 w-full" />;
}
export {
VerticalInputLayout as Vertical,
HorizontalInputLayout as Horizontal,
ErrorLayout as Error,
ErrorTextLayout,
FieldSeparator,
FieldPadder,
type FieldPadderProps,
};

View File

@@ -8,6 +8,7 @@ import {
SvgCloud,
SvgAws,
SvgOpenrouter,
SvgPlug,
SvgServer,
SvgAzure,
SvgGemini,
@@ -28,6 +29,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
[LLMProviderName.BIFROST]: SvgBifrost,
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
// fallback
[LLMProviderName.CUSTOM]: SvgServer,
@@ -45,6 +47,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
[LLMProviderName.OPENROUTER]: "OpenRouter",
[LLMProviderName.LM_STUDIO]: "LM Studio",
[LLMProviderName.BIFROST]: "Bifrost",
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
// fallback
[LLMProviderName.CUSTOM]: "Custom Models",
@@ -62,6 +65,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
[LLMProviderName.OPENROUTER]: "OpenRouter",
[LLMProviderName.LM_STUDIO]: "LM Studio",
[LLMProviderName.BIFROST]: "Bifrost",
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
// fallback
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",

View File

@@ -31,6 +31,7 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import Separator from "@/refresh-components/Separator";
import {
LLMProviderName,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
@@ -43,9 +44,10 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
import { Section } from "@/layouts/general-layouts";
const route = ADMIN_ROUTES.LLM_MODELS;
@@ -57,16 +59,18 @@ const route = ADMIN_ROUTES.LLM_MODELS;
// Client-side ordering for the "Add Provider" cards. The backend may return
// wellKnownLLMProviders in an arbitrary order, so we sort explicitly here.
const PROVIDER_DISPLAY_ORDER: string[] = [
"openai",
"anthropic",
"vertex_ai",
"bedrock",
"azure",
"litellm_proxy",
"ollama_chat",
"openrouter",
"lm_studio",
"bifrost",
LLMProviderName.OPENAI,
LLMProviderName.ANTHROPIC,
LLMProviderName.VERTEX_AI,
LLMProviderName.BEDROCK,
LLMProviderName.AZURE,
"litellm",
LLMProviderName.LITELLM_PROXY,
LLMProviderName.OLLAMA_CHAT,
LLMProviderName.OPENROUTER,
LLMProviderName.LM_STUDIO,
LLMProviderName.BIFROST,
LLMProviderName.OPENAI_COMPATIBLE,
];
const PROVIDER_MODAL_MAP: Record<
@@ -127,7 +131,7 @@ const PROVIDER_MODAL_MAP: Record<
/>
),
lm_studio: (d, open, onOpenChange) => (
<LMStudioForm
<LMStudioModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
@@ -147,6 +151,13 @@ const PROVIDER_MODAL_MAP: Record<
onOpenChange={onOpenChange}
/>
),
openai_compatible: (d, open, onOpenChange) => (
<OpenAICompatibleModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
};
// ============================================================================
@@ -341,7 +352,7 @@ function NewCustomProviderCard({
// LLMConfigurationPage — main page component
// ============================================================================
export default function LLMConfigurationPage() {
export default function LLMProviderConfigurationPage() {
const { mutate } = useSWRConfig();
const { llmProviders: existingLlmProviders, defaultText } =
useAdminLLMProviders();

View File

@@ -1,33 +1,23 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { Formik } from "formik";
import { LLMProviderFormProps } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
ModelsField,
ModelSelectionField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
const ANTHROPIC_PROVIDER_NAME = "anthropic";
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
import * as InputLayouts from "@/layouts/input-layouts";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
export default function AnthropicModal({
variant = "llm-configuration",
@@ -35,143 +25,78 @@ export default function AnthropicModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
ANTHROPIC_PROVIDER_NAME
const initialValues = useInitialValues(
isOnboarding,
LLMProviderName.ANTHROPIC,
existingLlmProvider
);
if (open === false) return null;
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues = isOnboarding
? {
...buildOnboardingInitialValues(),
name: ANTHROPIC_PROVIDER_NAME,
provider: ANTHROPIC_PROVIDER_NAME,
api_key: "",
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
}
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? undefined,
default_model_name:
(defaultModelName &&
modelConfigurations.some((m) => m.name === defaultModelName)
? defaultModelName
: undefined) ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.ANTHROPIC}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: ANTHROPIC_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
is_auto_mode:
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: ANTHROPIC_PROVIDER_NAME,
values,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.ANTHROPIC,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<APIKeyField providerName="Anthropic" />
<APIKeyField providerName="Anthropic" />
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. claude-sonnet-4-5" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
</Formik>
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
);
}

View File

@@ -1,41 +1,35 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { Formik } from "formik";
import { useFormikContext } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
} from "@/interfaces/llm";
import * as Yup from "yup";
import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
DisplayNameField,
FieldSeparator,
FieldWrapper,
ModelsAccessField,
ModelsField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModelSelectionField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import {
isValidAzureTargetUri,
parseAzureTargetUri,
} from "@/lib/azureTargetUri";
import { toast } from "@/hooks/useToast";
const AZURE_PROVIDER_NAME = "azure";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
interface AzureModalValues extends BaseLLMFormValues {
api_key: string;
@@ -45,6 +39,33 @@ interface AzureModalValues extends BaseLLMFormValues {
deployment_name?: string;
}
function AzureModelSelection() {
const formikProps = useFormikContext<AzureModalValues>();
return (
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onAddModel={(modelName) => {
const current = formikProps.values.model_configurations;
if (current.some((m) => m.name === modelName)) return;
const updated = [
...current,
{
name: modelName,
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
},
];
formikProps.setFieldValue("model_configurations", updated);
if (!formikProps.values.test_model_name) {
formikProps.setFieldValue("test_model_name", modelName);
}
}}
/>
);
}
function buildTargetUri(existingLlmProvider?: LLMProviderView): string {
if (!existingLlmProvider?.api_base || !existingLlmProvider?.api_version) {
return "";
@@ -81,160 +102,105 @@ export default function AzureModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(AZURE_PROVIDER_NAME);
if (open === false) return null;
const initialValues: AzureModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.AZURE,
existingLlmProvider
),
target_uri: buildTargetUri(existingLlmProvider),
} as AzureModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
extra: {
target_uri: Yup.string()
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
(value) => (value ? isValidAzureTargetUri(value) : false)
),
},
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: AzureModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: AZURE_PROVIDER_NAME,
provider: AZURE_PROVIDER_NAME,
api_key: "",
target_uri: "",
default_model_name: "",
} as AzureModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
target_uri: buildTargetUri(existingLlmProvider),
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
target_uri: Yup.string()
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
(value) => (value ? isValidAzureTargetUri(value) : false)
),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
target_uri: Yup.string()
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
(value) => (value ? isValidAzureTargetUri(value) : false)
),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.AZURE}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
onSubmit={async (values, { setSubmitting, setStatus }) => {
const processedValues = processValues(values);
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: AZURE_PROVIDER_NAME,
payload: {
...processedValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: AZURE_PROVIDER_NAME,
values: processedValues,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.AZURE,
values: processedValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={AZURE_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="target_uri"
title="Target URI"
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
>
<FieldWrapper>
<InputLayouts.Vertical
name="target_uri"
title="Target URI"
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
>
<InputTypeInField
name="target_uri"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<InputTypeInField
name="target_uri"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<APIKeyField providerName="Azure" />
<APIKeyField providerName="Azure" />
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
</Formik>
<InputLayouts.FieldSeparator />
<AzureModelSelection />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
);
}

View File

@@ -1,8 +1,8 @@
"use client";
import { useState, useEffect } from "react";
import { useEffect } from "react";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import { useFormikContext } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
@@ -10,30 +10,22 @@ import * as InputLayouts from "@/layouts/input-layouts";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
ModelSelectionField,
DisplayNameField,
FieldSeparator,
FieldWrapper,
ModelsAccessField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
import { Card } from "@opal/components";
@@ -41,9 +33,9 @@ import { Section } from "@/layouts/general-layouts";
import { SvgAlertCircle } from "@opal/icons";
import { Content } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import useOnMount from "@/hooks/useOnMount";
const BEDROCK_PROVIDER_NAME = "bedrock";
const AWS_REGION_OPTIONS = [
{ name: "us-east-1", value: "us-east-1" },
{ name: "us-east-2", value: "us-east-2" },
@@ -79,26 +71,15 @@ interface BedrockModalValues extends BaseLLMFormValues {
}
interface BedrockModalInternalsProps {
formikProps: FormikProps<BedrockModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function BedrockModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: BedrockModalInternalsProps) {
const formikProps = useFormikContext<BedrockModalValues>();
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
useEffect(() => {
@@ -115,11 +96,6 @@ function BedrockModalInternals({
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [authMethod]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const isAuthComplete =
authMethod === AUTH_METHOD_IAM ||
(authMethod === AUTH_METHOD_ACCESS_KEY &&
@@ -139,12 +115,12 @@ function BedrockModalInternals({
formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
aws_bearer_token_bedrock:
formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
provider_name: existingLlmProvider?.name,
provider_name: LLMProviderName.BEDROCK,
});
if (error) {
throw new Error(error);
}
setFetchedModels(models);
formikProps.setFieldValue("model_configurations", models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -159,16 +135,8 @@ function BedrockModalInternals({
});
return (
<LLMConfigurationModalWrapper
providerEndpoint={BEDROCK_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<>
<InputLayouts.FieldPadder>
<Section gap={1}>
<InputLayouts.Vertical
name={FIELD_AWS_REGION_NAME}
@@ -222,7 +190,7 @@ function BedrockModalInternals({
</InputSelect>
</InputLayouts.Vertical>
</Section>
</FieldWrapper>
</InputLayouts.FieldPadder>
{authMethod === AUTH_METHOD_ACCESS_KEY && (
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
@@ -250,7 +218,7 @@ function BedrockModalInternals({
)}
{authMethod === AUTH_METHOD_IAM && (
<FieldWrapper>
<InputLayouts.FieldPadder>
<Card backgroundVariant="none" borderVariant="solid">
<Content
icon={SvgAlertCircle}
@@ -259,7 +227,7 @@ function BedrockModalInternals({
sizePreset="main-ui"
/>
</Card>
</FieldWrapper>
</InputLayouts.FieldPadder>
)}
{authMethod === AUTH_METHOD_LONG_TERM_API_KEY && (
@@ -280,32 +248,24 @@ function BedrockModalInternals({
{!isOnboarding && (
<>
<FieldSeparator />
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. us.anthropic.claude-sonnet-4-5-v1" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</LLMConfigurationModalWrapper>
</>
);
}
@@ -315,86 +275,54 @@ export default function BedrockModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
BEDROCK_PROVIDER_NAME
);
if (open === false) return null;
const initialValues: BedrockModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.BEDROCK,
existingLlmProvider
),
custom_config: {
AWS_REGION_NAME:
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ?? "",
BEDROCK_AUTH_METHOD:
(existingLlmProvider?.custom_config?.BEDROCK_AUTH_METHOD as string) ??
"access_key",
AWS_ACCESS_KEY_ID:
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ?? "",
AWS_SECRET_ACCESS_KEY:
(existingLlmProvider?.custom_config?.AWS_SECRET_ACCESS_KEY as string) ??
"",
AWS_BEARER_TOKEN_BEDROCK:
(existingLlmProvider?.custom_config
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
},
} as BedrockModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
extra: {
custom_config: Yup.object({
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
}),
},
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: BedrockModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: BEDROCK_PROVIDER_NAME,
provider: BEDROCK_PROVIDER_NAME,
default_model_name: "",
custom_config: {
AWS_REGION_NAME: "",
BEDROCK_AUTH_METHOD: "access_key",
AWS_ACCESS_KEY_ID: "",
AWS_SECRET_ACCESS_KEY: "",
AWS_BEARER_TOKEN_BEDROCK: "",
},
} as BedrockModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
custom_config: {
AWS_REGION_NAME:
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??
"",
BEDROCK_AUTH_METHOD:
(existingLlmProvider?.custom_config
?.BEDROCK_AUTH_METHOD as string) ?? "access_key",
AWS_ACCESS_KEY_ID:
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ??
"",
AWS_SECRET_ACCESS_KEY:
(existingLlmProvider?.custom_config
?.AWS_SECRET_ACCESS_KEY as string) ?? "",
AWS_BEARER_TOKEN_BEDROCK:
(existingLlmProvider?.custom_config
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
default_model_name: Yup.string().required("Model name is required"),
custom_config: Yup.object({
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
}),
})
: buildDefaultValidationSchema().shape({
custom_config: Yup.object({
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
}),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.BEDROCK}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
onSubmit={async (values, { setSubmitting, setStatus }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
);
@@ -407,51 +335,37 @@ export default function BedrockModal({
: undefined,
};
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: BEDROCK_PROVIDER_NAME,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: BEDROCK_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.BEDROCK,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<BedrockModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
<BedrockModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
);
}

View File

@@ -1,45 +1,33 @@
"use client";
import { useState, useEffect } from "react";
import { useEffect } from "react";
import { markdown } from "@opal/utils";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
APIBaseField,
APIKeyField,
ModelSelectionField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
const DEFAULT_API_BASE = "";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
interface BifrostModalValues extends BaseLLMFormValues {
api_key: string;
@@ -47,30 +35,15 @@ interface BifrostModalValues extends BaseLLMFormValues {
}
interface BifrostModalInternalsProps {
formikProps: FormikProps<BifrostModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function BifrostModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: BifrostModalInternalsProps) {
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const formikProps = useFormikContext<BifrostModalValues>();
const isFetchDisabled = !formikProps.values.api_base;
@@ -78,12 +51,12 @@ function BifrostModalInternals({
const { models, error } = await fetchBifrostModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key || undefined,
provider_name: existingLlmProvider?.name,
provider_name: LLMProviderName.BIFROST,
});
if (error) {
throw new Error(error);
}
setFetchedModels(models);
formikProps.setFieldValue("model_configurations", models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -100,69 +73,39 @@ function BifrostModalInternals({
}, []);
return (
<LLMConfigurationModalWrapper
providerEndpoint={LLMProviderName.BIFROST}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
>
<InputTypeInField
name="api_base"
placeholder="https://your-bifrost-gateway.com/v1"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<>
<APIBaseField
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
placeholder="https://your-bifrost-gateway.com/v1"
/>
<FieldWrapper>
<InputLayouts.Vertical
name="api_key"
title="API Key"
optional={true}
subDescription={markdown(
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
)}
>
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
</InputLayouts.Vertical>
</FieldWrapper>
<APIKeyField
optional
subDescription={markdown(
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
)}
/>
{!isOnboarding && (
<>
<FieldSeparator />
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</LLMConfigurationModalWrapper>
</>
);
}
@@ -172,107 +115,64 @@ export default function BifrostModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
BIFROST_PROVIDER_NAME
);
if (open === false) return null;
const initialValues: BifrostModalValues = useInitialValues(
isOnboarding,
LLMProviderName.BIFROST,
existingLlmProvider
) as BifrostModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
apiBase: true,
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: BifrostModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: BIFROST_PROVIDER_NAME,
provider: BIFROST_PROVIDER_NAME,
api_key: "",
api_base: DEFAULT_API_BASE,
default_model_name: "",
} as BifrostModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_base: Yup.string().required("API Base URL is required"),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.BIFROST}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: BIFROST_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: BIFROST_PROVIDER_NAME,
values,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.BIFROST,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<BifrostModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
<BifrostModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
);
}

View File

@@ -70,7 +70,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
}
) {
const nameInput = screen.getByPlaceholderText("Display Name");
const providerInput = screen.getByPlaceholderText("Provider Name");
const providerInput = screen.getByPlaceholderText(
"Provider Name as shown on LiteLLM"
);
await user.type(nameInput, options.name);
await user.type(providerInput, options.provider);
@@ -498,7 +500,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
const nameInput = screen.getByPlaceholderText("Display Name");
await user.type(nameInput, "Cloudflare Provider");
const providerInput = screen.getByPlaceholderText("Provider Name");
const providerInput = screen.getByPlaceholderText(
"Provider Name as shown on LiteLLM"
);
await user.type(providerInput, "cloudflare");
// Click "Add Line" button for custom config (aria-label from KeyValueInput)

View File

@@ -1,24 +1,23 @@
"use client";
import { useState } from "react";
import { markdown } from "@opal/utils";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import { LLMProviderFormProps, ModelConfiguration } from "@/interfaces/llm";
import { useFormikContext } from "formik";
import {
LLMProviderFormProps,
LLMProviderName,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useInitialValues } from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
buildDefaultInitialValues,
buildOnboardingInitialValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
APIKeyField,
APIBaseField,
DisplayNameField,
FieldSeparator,
ModelsAccessField,
LLMConfigurationModalWrapper,
FieldWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
@@ -32,6 +31,7 @@ import { Button, Card, EmptyMessageCard } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { Content } from "@opal/layouts";
import { Section } from "@/layouts/general-layouts";
@@ -109,13 +109,10 @@ function ModelConfigurationItem({
);
}
interface ModelConfigurationListProps {
formikProps: FormikProps<{
function ModelConfigurationList() {
const formikProps = useFormikContext<{
model_configurations: CustomModelConfiguration[];
}>;
}
function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
}>();
const models = formikProps.values.model_configurations;
function handleChange(index: number, next: CustomModelConfiguration) {
@@ -181,6 +178,19 @@ function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
);
}
function CustomConfigKeyValue() {
const formikProps = useFormikContext<{ custom_config_list: KeyValue[] }>();
return (
<KeyValueInput
items={formikProps.values.custom_config_list}
onChange={(items) =>
formikProps.setFieldValue("custom_config_list", items)
}
addButtonLabel="Add Line"
/>
);
}
// ─── Custom Config Processing ─────────────────────────────────────────────────
function customConfigProcessing(items: KeyValue[]) {
@@ -197,39 +207,36 @@ export default function CustomModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
onSuccess,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues = {
...buildDefaultInitialValues(
existingLlmProvider,
undefined,
defaultModelName
...useInitialValues(
isOnboarding,
LLMProviderName.CUSTOM,
existingLlmProvider
),
...(isOnboarding ? buildOnboardingInitialValues() : {}),
provider: existingLlmProvider?.provider ?? "",
api_version: existingLlmProvider?.api_version ?? "",
model_configurations: existingLlmProvider?.model_configurations.map(
(mc) => ({
name: mc.name,
display_name: mc.display_name ?? "",
is_visible: mc.is_visible,
max_input_tokens: mc.max_input_tokens ?? null,
supports_image_input: mc.supports_image_input,
supports_reasoning: mc.supports_reasoning,
})
) ?? [
{
name: "",
display_name: "",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
},
],
custom_config_list: existingLlmProvider?.custom_config
@@ -260,12 +267,18 @@ export default function CustomModal({
model_configurations: Yup.array(modelConfigurationSchema),
});
const onClose = () => onOpenChange?.(false);
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.CUSTOM}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
onSubmit={async (values, { setSubmitting, setStatus }) => {
setSubmitting(true);
const modelConfigurations = values.model_configurations
@@ -285,127 +298,123 @@ export default function CustomModal({
return;
}
if (isOnboarding && onboardingState && onboardingActions) {
await submitOnboardingProvider({
providerName: values.provider,
payload: {
...values,
model_configurations: modelConfigurations,
custom_config: customConfigProcessing(values.custom_config_list),
},
onboardingState,
onboardingActions,
isCustomProvider: true,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
const selectedModelNames = modelConfigurations.map(
(config) => config.name
);
// Always send custom_config as a dict (even empty) so the backend
// preserves it as non-null — this is the signal that the provider was
// created via CustomModal.
const customConfig = customConfigProcessing(values.custom_config_list);
await submitLLMProvider({
providerName: values.provider,
values: {
...values,
selected_model_names: selectedModelNames,
custom_config: customConfigProcessing(values.custom_config_list),
},
initialValues: {
...initialValues,
custom_config: customConfigProcessing(
initialValues.custom_config_list
),
},
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: (values as Record<string, unknown>).provider as string,
values: {
...values,
model_configurations: modelConfigurations,
custom_config: customConfig,
},
initialValues: {
...initialValues,
custom_config: customConfigProcessing(
initialValues.custom_config_list
),
},
existingLlmProvider,
shouldMarkAsDefault,
isCustomProvider: true,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint="custom"
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="provider"
title="Provider Name"
subDescription={markdown(
"Should be one of the providers listed at [LiteLLM](https://docs.litellm.ai/docs/providers)."
)}
>
{!isOnboarding && (
<Section gap={0}>
<DisplayNameField disabled={!!existingLlmProvider} />
<InputTypeInField
name="provider"
placeholder="Provider Name as shown on LiteLLM"
variant={existingLlmProvider ? "disabled" : undefined}
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="provider"
title="Provider Name"
subDescription="Should be one of the providers listed at https://docs.litellm.ai/docs/providers."
>
<InputTypeInField
name="provider"
placeholder="Provider Name"
variant={existingLlmProvider ? "disabled" : undefined}
/>
</InputLayouts.Vertical>
</FieldWrapper>
</Section>
)}
<APIBaseField optional />
<FieldSeparator />
<InputLayouts.FieldPadder>
<InputLayouts.Vertical name="api_version" title="API Version" optional>
<InputTypeInField name="api_version" />
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldWrapper>
<Section gap={0.75}>
<Content
title="Provider Configs"
description="Add properties as needed by the model provider. This is passed to LiteLLM completion() call as arguments in the environment variable. See LiteLLM documentation for more instructions."
widthVariant="full"
variant="section"
sizePreset="main-content"
/>
<APIKeyField
optional
subDescription="Paste your API key if your model provider requires authentication."
/>
<KeyValueInput
items={formikProps.values.custom_config_list}
onChange={(items) =>
formikProps.setFieldValue("custom_config_list", items)
}
addButtonLabel="Add Line"
/>
</Section>
</FieldWrapper>
<InputLayouts.FieldPadder>
<Section gap={0.75}>
<Content
title="Additional Configs"
description={markdown(
"Add extra properties as needed by the model provider. These are passed to LiteLLM's `completion()` call as [environment variables](https://docs.litellm.ai/docs/set_keys#environment-variables). See [documentation](https://docs.onyx.app/admins/ai_models/custom_inference_provider) for more instructions."
)}
widthVariant="full"
variant="section"
sizePreset="main-content"
/>
<FieldSeparator />
<CustomConfigKeyValue />
</Section>
</InputLayouts.FieldPadder>
<Section gap={0.5}>
<FieldWrapper>
<Content
title="Models"
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
variant="section"
sizePreset="main-content"
widthVariant="full"
/>
</FieldWrapper>
<Card>
<ModelConfigurationList formikProps={formikProps as any} />
</Card>
</Section>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
</Formik>
<InputLayouts.FieldSeparator />
<Section gap={0.5}>
<InputLayouts.FieldPadder>
<Content
title="Models"
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
variant="section"
sizePreset="main-content"
widthVariant="full"
/>
</InputLayouts.FieldPadder>
<Card sizeVariant="lg">
<ModelConfigurationList />
</Card>
</Section>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
);
}

View File

@@ -1,315 +0,0 @@
"use client";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
import { toast } from "@/hooks/useToast";
const DEFAULT_API_BASE = "http://localhost:1234";
interface LMStudioFormValues extends BaseLLMFormValues {
api_base: string;
custom_config: {
LM_STUDIO_API_KEY?: string;
};
}
interface LMStudioFormInternalsProps {
formikProps: FormikProps<LMStudioFormValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function LMStudioFormInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
isTesting,
onClose,
isOnboarding,
}: LMStudioFormInternalsProps) {
const initialApiKey =
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ?? "";
const doFetchModels = useCallback(
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
fetchModels(
LLMProviderName.LM_STUDIO,
{
api_base: apiBase,
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
api_key_changed: apiKey !== initialApiKey,
name: existingLlmProvider?.name,
},
signal
).then((data) => {
if (signal.aborted) return;
if (data.error) {
toast.error(data.error);
setFetchedModels([]);
return;
}
setFetchedModels(data.models);
});
},
[existingLlmProvider?.name, initialApiKey, setFetchedModels]
);
const debouncedFetchModels = useMemo(
() => debounce(doFetchModels, 500),
[doFetchModels]
);
const apiBase = formikProps.values.api_base;
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
useEffect(() => {
if (apiBase) {
const controller = new AbortController();
debouncedFetchModels(apiBase, apiKey, controller.signal);
return () => {
debouncedFetchModels.cancel();
controller.abort();
};
} else {
setFetchedModels([]);
}
}, [apiBase, apiKey, debouncedFetchModels, setFetchedModels]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || [];
return (
<LLMConfigurationModalWrapper
providerEndpoint={LLMProviderName.LM_STUDIO}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="The base URL for your LM Studio server."
>
<InputTypeInField
name="api_base"
placeholder="Your LM Studio API base URL"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<FieldWrapper>
<InputLayouts.Vertical
name="custom_config.LM_STUDIO_API_KEY"
title="API Key"
subDescription="Optional API key if your LM Studio server requires authentication."
optional
>
<PasswordInputTypeInField
name="custom_config.LM_STUDIO_API_KEY"
placeholder="API Key"
/>
</InputLayouts.Vertical>
</FieldWrapper>
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. llama3.1" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
);
}
export default function LMStudioForm({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
LLMProviderName.LM_STUDIO
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: LMStudioFormValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: LLMProviderName.LM_STUDIO,
provider: LLMProviderName.LM_STUDIO,
api_base: DEFAULT_API_BASE,
default_model_name: "",
custom_config: {
LM_STUDIO_API_KEY: "",
},
} as LMStudioFormValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
LM_STUDIO_API_KEY:
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ??
"",
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_base: Yup.string().required("API Base URL is required"),
});
return (
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
);
const submitValues = {
...values,
custom_config:
Object.keys(filteredCustomConfig).length > 0
? filteredCustomConfig
: undefined,
};
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: LLMProviderName.LM_STUDIO,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: LLMProviderName.LM_STUDIO,
values: submitValues,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
{(formikProps) => (
<LMStudioFormInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
);
}

View File

@@ -0,0 +1,217 @@
"use client";
import { useCallback, useEffect, useMemo } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
} from "@/interfaces/llm";
import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues as BaseLLMModalValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
APIBaseField,
ModelSelectionField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const DEFAULT_API_BASE = "http://localhost:1234";
interface LMStudioModalValues extends BaseLLMModalValues {
api_base: string;
custom_config: {
LM_STUDIO_API_KEY?: string;
};
}
interface LMStudioModalInternalsProps {
existingLlmProvider: LLMProviderView | undefined;
isOnboarding: boolean;
}
function LMStudioModalInternals({
existingLlmProvider,
isOnboarding,
}: LMStudioModalInternalsProps) {
const formikProps = useFormikContext<LMStudioModalValues>();
const initialApiKey = existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY;
const doFetchModels = useCallback(
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
fetchModels(
LLMProviderName.LM_STUDIO,
{
api_base: apiBase,
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
api_key_changed: apiKey !== initialApiKey,
name: existingLlmProvider?.name,
},
signal
).then((data) => {
if (signal.aborted) return;
if (data.error) {
toast.error(data.error);
formikProps.setFieldValue("model_configurations", []);
return;
}
formikProps.setFieldValue("model_configurations", data.models);
});
},
// eslint-disable-next-line react-hooks/exhaustive-deps
[existingLlmProvider?.name, initialApiKey]
);
const debouncedFetchModels = useMemo(
() => debounce(doFetchModels, 500),
[doFetchModels]
);
const apiBase = formikProps.values.api_base;
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
useEffect(() => {
if (apiBase) {
const controller = new AbortController();
debouncedFetchModels(apiBase, apiKey, controller.signal);
return () => {
debouncedFetchModels.cancel();
controller.abort();
};
} else {
formikProps.setFieldValue("model_configurations", []);
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [apiBase, apiKey, debouncedFetchModels]);
return (
<>
<APIBaseField
subDescription="The base URL for your LM Studio server."
placeholder="Your LM Studio API base URL"
/>
<APIKeyField
name="custom_config.LM_STUDIO_API_KEY"
optional
subDescription="Optional API key if your LM Studio server requires authentication."
/>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={false} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</>
);
}
export default function LMStudioModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const initialValues: LMStudioModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.LM_STUDIO,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
LM_STUDIO_API_KEY: existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY,
},
} as LMStudioModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
apiBase: true,
});
const onClose = () => onOpenChange?.(false);
if (open === false) return null;
return (
<ModalWrapper
providerName={LLMProviderName.LM_STUDIO}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
);
const submitValues = {
...values,
custom_config:
Object.keys(filteredCustomConfig).length > 0
? filteredCustomConfig
: undefined,
};
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.LM_STUDIO,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
<LMStudioModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
);
}

View File

@@ -1,41 +1,32 @@
"use client";
import { useState, useEffect } from "react";
import { useEffect } from "react";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchLiteLLMProxyModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
ModelsField,
APIBaseField,
ModelSelectionField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const DEFAULT_API_BASE = "http://localhost:4000";
@@ -45,30 +36,15 @@ interface LiteLLMProxyModalValues extends BaseLLMFormValues {
}
interface LiteLLMProxyModalInternalsProps {
formikProps: FormikProps<LiteLLMProxyModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function LiteLLMProxyModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: LiteLLMProxyModalInternalsProps) {
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const formikProps = useFormikContext<LiteLLMProxyModalValues>();
const isFetchDisabled =
!formikProps.values.api_base || !formikProps.values.api_key;
@@ -77,12 +53,12 @@ function LiteLLMProxyModalInternals({
const { models, error } = await fetchLiteLLMProxyModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key,
provider_name: existingLlmProvider?.name,
provider_name: LLMProviderName.LITELLM_PROXY,
});
if (error) {
throw new Error(error);
}
setFetchedModels(models);
formikProps.setFieldValue("model_configurations", models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -98,58 +74,34 @@ function LiteLLMProxyModalInternals({
}, []);
return (
<LLMConfigurationModalWrapper
providerEndpoint={LLMProviderName.LITELLM_PROXY}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="The base URL for your LiteLLM Proxy server."
>
<InputTypeInField
name="api_base"
placeholder="https://your-litellm-proxy.com"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<>
<APIBaseField
subDescription="The base URL for your LiteLLM Proxy server."
placeholder="https://your-litellm-proxy.com"
/>
<APIKeyField providerName="LiteLLM Proxy" />
{!isOnboarding && (
<>
<FieldSeparator />
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</LLMConfigurationModalWrapper>
</>
);
}
@@ -159,109 +111,68 @@ export default function LiteLLMProxyModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
LLMProviderName.LITELLM_PROXY
);
if (open === false) return null;
const initialValues: LiteLLMProxyModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.LITELLM_PROXY,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
} as LiteLLMProxyModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
apiBase: true,
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: LiteLLMProxyModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: LLMProviderName.LITELLM_PROXY,
provider: LLMProviderName.LITELLM_PROXY,
api_key: "",
api_base: DEFAULT_API_BASE,
default_model_name: "",
} as LiteLLMProxyModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.LITELLM_PROXY}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: LLMProviderName.LITELLM_PROXY,
payload: {
...values,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: LLMProviderName.LITELLM_PROXY,
values,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.LITELLM_PROXY,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<LiteLLMProxyModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
<LiteLLMProxyModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
);
}

View File

@@ -1,47 +1,44 @@
"use client";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import * as Yup from "yup";
import { Dispatch, SetStateAction, useMemo, useState } from "react";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
ModelSelectionField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
import Tabs from "@/refresh-components/Tabs";
import { Card } from "@opal/components";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import useOnMount from "@/hooks/useOnMount";
const OLLAMA_PROVIDER_NAME = "ollama_chat";
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
const TAB_SELF_HOSTED = "self-hosted";
const TAB_CLOUD = "cloud";
const CLOUD_API_BASE = "https://ollama.com";
enum Tab {
TAB_SELF_HOSTED = "self-hosted",
TAB_CLOUD = "cloud",
}
interface OllamaModalValues extends BaseLLMFormValues {
api_base: string;
@@ -51,104 +48,65 @@ interface OllamaModalValues extends BaseLLMFormValues {
}
interface OllamaModalInternalsProps {
formikProps: FormikProps<OllamaModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
tab: Tab;
setTab: Dispatch<SetStateAction<Tab>>;
}
function OllamaModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
isTesting,
onClose,
isOnboarding,
tab,
setTab,
}: OllamaModalInternalsProps) {
const isInitialMount = useRef(true);
const formikProps = useFormikContext<OllamaModalValues>();
const doFetchModels = useCallback(
(apiBase: string, signal: AbortSignal) => {
fetchOllamaModels({
api_base: apiBase,
provider_name: existingLlmProvider?.name,
signal,
}).then((data) => {
if (signal.aborted) return;
if (data.error) {
toast.error(data.error);
setFetchedModels([]);
return;
}
setFetchedModels(data.models);
const isFetchDisabled = useMemo(
() =>
tab === Tab.TAB_SELF_HOSTED
? !formikProps.values.api_base
: !formikProps.values.custom_config.OLLAMA_API_KEY,
[tab, formikProps]
);
const handleFetchModels = async () => {
// Only Ollama cloud accepts API key
const apiBase = formikProps.values.custom_config?.OLLAMA_API_KEY
? CLOUD_API_BASE
: formikProps.values.api_base;
const { models, error } = await fetchOllamaModels({
api_base: apiBase,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
};
// Auto-fetch models on initial load when editing an existing provider
useOnMount(() => {
if (existingLlmProvider) {
handleFetchModels().catch((err) => {
toast.error(
err instanceof Error ? err.message : "Failed to fetch models"
);
});
},
[existingLlmProvider?.name, setFetchedModels]
);
const debouncedFetchModels = useMemo(
() => debounce(doFetchModels, 500),
[doFetchModels]
);
// Skip the initial fetch for new providers — api_base starts with a default
// value, which would otherwise trigger a fetch before the user has done
// anything. Existing providers should still auto-fetch on mount.
useEffect(() => {
if (isInitialMount.current) {
isInitialMount.current = false;
if (!existingLlmProvider) return;
}
if (formikProps.values.api_base) {
const controller = new AbortController();
debouncedFetchModels(formikProps.values.api_base, controller.signal);
return () => {
debouncedFetchModels.cancel();
controller.abort();
};
} else {
setFetchedModels([]);
}
}, [
formikProps.values.api_base,
debouncedFetchModels,
setFetchedModels,
existingLlmProvider,
]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || [];
const hasApiKey = !!formikProps.values.custom_config?.OLLAMA_API_KEY;
const defaultTab =
existingLlmProvider && hasApiKey ? TAB_CLOUD : TAB_SELF_HOSTED;
});
return (
<LLMConfigurationModalWrapper
providerEndpoint={OLLAMA_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<>
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
<Tabs defaultValue={defaultTab}>
<Tabs value={tab} onValueChange={(value) => setTab(value as Tab)}>
<Tabs.List>
<Tabs.Trigger value={TAB_SELF_HOSTED}>
<Tabs.Trigger value={Tab.TAB_SELF_HOSTED}>
Self-hosted Ollama
</Tabs.Trigger>
<Tabs.Trigger value={TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
<Tabs.Trigger value={Tab.TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
</Tabs.List>
<Tabs.Content value={TAB_SELF_HOSTED}>
<Tabs.Content value={Tab.TAB_SELF_HOSTED} padding={0}>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
@@ -161,7 +119,7 @@ function OllamaModalInternals({
</InputLayouts.Vertical>
</Tabs.Content>
<Tabs.Content value={TAB_CLOUD}>
<Tabs.Content value={Tab.TAB_CLOUD}>
<InputLayouts.Vertical
name="custom_config.OLLAMA_API_KEY"
title="API Key"
@@ -178,31 +136,24 @@ function OllamaModalInternals({
{!isOnboarding && (
<>
<FieldSeparator />
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. llama3.1" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</LLMConfigurationModalWrapper>
</>
);
}
@@ -212,67 +163,55 @@ export default function OllamaModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } =
useWellKnownLLMProvider(OLLAMA_PROVIDER_NAME);
const apiKey = existingLlmProvider?.custom_config?.OLLAMA_API_KEY;
const defaultTab =
existingLlmProvider && !!apiKey ? Tab.TAB_CLOUD : Tab.TAB_SELF_HOSTED;
const [tab, setTab] = useState<Tab>(defaultTab);
if (open === false) return null;
const initialValues: OllamaModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.OLLAMA_CHAT,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
OLLAMA_API_KEY: apiKey,
},
} as OllamaModalValues;
const validationSchema = useMemo(
() =>
buildValidationSchema(isOnboarding, {
apiBase: tab === Tab.TAB_SELF_HOSTED,
extra:
tab === Tab.TAB_CLOUD
? {
custom_config: Yup.object({
OLLAMA_API_KEY: Yup.string().required("API Key is required"),
}),
}
: undefined,
}),
[tab, isOnboarding]
);
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: OllamaModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: OLLAMA_PROVIDER_NAME,
provider: OLLAMA_PROVIDER_NAME,
api_base: DEFAULT_API_BASE,
default_model_name: "",
custom_config: {
OLLAMA_API_KEY: "",
},
} as OllamaModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
OLLAMA_API_KEY:
(existingLlmProvider?.custom_config?.OLLAMA_API_KEY as string) ??
"",
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_base: Yup.string().required("API Base URL is required"),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.OLLAMA_CHAT}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
onSubmit={async (values, { setSubmitting, setStatus }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
);
@@ -285,50 +224,39 @@ export default function OllamaModal({
: undefined,
};
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: OLLAMA_PROVIDER_NAME,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: OLLAMA_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.OLLAMA_CHAT,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<OllamaModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
<OllamaModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
tab={tab}
setTab={setTab}
/>
</ModalWrapper>
);
}

View File

@@ -0,0 +1,177 @@
"use client";
import { useEffect } from "react";
import { markdown } from "@opal/utils";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
} from "@/interfaces/llm";
import { fetchOpenAICompatibleModels } from "@/app/admin/configuration/llm/utils";
import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIBaseField,
APIKeyField,
ModelSelectionField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
// Form values for the OpenAI-compatible provider modal. Extends the shared
// base form shape with the two connection fields this provider needs.
interface OpenAICompatibleModalValues extends BaseLLMFormValues {
  // May be empty — the API key is optional for servers without auth.
  api_key: string;
  api_base: string;
}
// Props for the modal body. Formik state is NOT passed here — the body reads
// it via useFormikContext, so it must render inside the ModalWrapper's Formik.
interface OpenAICompatibleModalInternalsProps {
  // undefined when creating a new provider; set when editing an existing one.
  existingLlmProvider: LLMProviderView | undefined;
  isOnboarding: boolean;
}
// Form body for the OpenAI-compatible provider modal. Reads/writes form state
// through Formik context; fetches the server's model list and stores it in
// the "model_configurations" field.
function OpenAICompatibleModalInternals({
  existingLlmProvider,
  isOnboarding,
}: OpenAICompatibleModalInternalsProps) {
  const formikProps = useFormikContext<OpenAICompatibleModalValues>();
  // Without an API base there is nothing to query, so disable fetching.
  const isFetchDisabled = !formikProps.values.api_base;
  // Queries the configured server for its models; throws on a reported error
  // (callers surface it — ModelSelectionField's onRefetch or the mount effect).
  const handleFetchModels = async () => {
    const { models, error } = await fetchOpenAICompatibleModels({
      api_base: formikProps.values.api_base,
      // Empty-string key is normalized to undefined (key is optional).
      api_key: formikProps.values.api_key || undefined,
      provider_name: existingLlmProvider?.name,
    });
    if (error) {
      throw new Error(error);
    }
    formikProps.setFieldValue("model_configurations", models);
  };
  // Auto-fetch models on initial load when editing an existing provider
  useEffect(() => {
    if (existingLlmProvider && !isFetchDisabled) {
      handleFetchModels().catch((err) => {
        toast.error(
          err instanceof Error ? err.message : "Failed to fetch models"
        );
      });
    }
    // Deps intentionally empty: this should run once on mount with the
    // initial form values, not re-fire as the user edits api_base/api_key.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);
  return (
    <>
      <APIBaseField
        subDescription="The base URL of your OpenAI-compatible server."
        placeholder="http://localhost:8000/v1"
      />
      <APIKeyField
        optional
        subDescription={markdown(
          "Provide an API key if your server requires authentication."
        )}
      />
      {/* Display name is an admin-only concern, hidden during onboarding. */}
      {!isOnboarding && (
        <>
          <InputLayouts.FieldSeparator />
          <DisplayNameField disabled={!!existingLlmProvider} />
        </>
      )}
      <InputLayouts.FieldSeparator />
      <ModelSelectionField
        shouldShowAutoUpdateToggle={false}
        onRefetch={isFetchDisabled ? undefined : handleFetchModels}
      />
      {!isOnboarding && (
        <>
          <InputLayouts.FieldSeparator />
          <ModelAccessField />
        </>
      )}
    </>
  );
}
// Modal for configuring a generic OpenAI-compatible LLM provider. Renders in
// both the admin page and the chat-onboarding flow (variant selects which).
export default function OpenAICompatibleModal({
  variant = "llm-configuration",
  existingLlmProvider,
  shouldMarkAsDefault,
  open,
  onOpenChange,
  onSuccess,
}: LLMProviderFormProps) {
  const isOnboarding = variant === "onboarding";
  const { mutate } = useSWRConfig();
  // Hooks must run unconditionally, so resolve form config before the
  // early return below.
  const formInitialValues = useInitialValues(
    isOnboarding,
    LLMProviderName.OPENAI_COMPATIBLE,
    existingLlmProvider
  ) as OpenAICompatibleModalValues;
  const schema = buildValidationSchema(isOnboarding, { apiBase: true });
  const closeModal = () => onOpenChange?.(false);

  // After a successful save: defer to the caller-supplied callback when
  // present, otherwise refresh provider caches and show a toast.
  const handleSuccess = async () => {
    if (onSuccess) {
      await onSuccess();
      return;
    }
    await refreshLlmProviderCaches(mutate);
    toast.success(
      existingLlmProvider
        ? "Provider updated successfully!"
        : "Provider enabled successfully!"
    );
  };

  // Strict `=== false` check: only an explicit false hides the modal;
  // `undefined` still renders.
  if (open === false) return null;

  return (
    <ModalWrapper
      providerName={LLMProviderName.OPENAI_COMPATIBLE}
      llmProvider={existingLlmProvider}
      onClose={closeModal}
      initialValues={formInitialValues}
      validationSchema={schema}
      onSubmit={async (values, { setSubmitting, setStatus }) => {
        const analyticsSource = isOnboarding
          ? LLMProviderConfiguredSource.CHAT_ONBOARDING
          : LLMProviderConfiguredSource.ADMIN_PAGE;
        await submitProvider({
          analyticsSource,
          providerName: LLMProviderName.OPENAI_COMPATIBLE,
          values,
          initialValues: formInitialValues,
          existingLlmProvider,
          shouldMarkAsDefault,
          setStatus,
          setSubmitting,
          onClose: closeModal,
          onSuccess: handleSuccess,
        });
      }}
    >
      <OpenAICompatibleModalInternals
        existingLlmProvider={existingLlmProvider}
        isOnboarding={isOnboarding}
      />
    </ModalWrapper>
  );
}

View File

@@ -1,33 +1,23 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { Formik } from "formik";
import { LLMProviderFormProps } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
ModelsField,
ModelSelectionField,
DisplayNameField,
FieldSeparator,
ModelsAccessField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
const OPENAI_PROVIDER_NAME = "openai";
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
import * as InputLayouts from "@/layouts/input-layouts";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
export default function OpenAIModal({
variant = "llm-configuration",
@@ -35,141 +25,78 @@ export default function OpenAIModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } =
useWellKnownLLMProvider(OPENAI_PROVIDER_NAME);
if (open === false) return null;
const initialValues = useInitialValues(
isOnboarding,
LLMProviderName.OPENAI,
existingLlmProvider
);
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues = isOnboarding
? {
...buildOnboardingInitialValues(),
name: OPENAI_PROVIDER_NAME,
provider: OPENAI_PROVIDER_NAME,
api_key: "",
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
}
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
default_model_name:
(defaultModelName &&
modelConfigurations.some((m) => m.name === defaultModelName)
? defaultModelName
: undefined) ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.OPENAI}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: OPENAI_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
is_auto_mode:
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: OPENAI_PROVIDER_NAME,
values,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.OPENAI,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={OPENAI_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<APIKeyField providerName="OpenAI" />
<APIKeyField providerName="OpenAI" />
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gpt-5.2" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
</Formik>
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
);
}

View File

@@ -1,73 +1,50 @@
"use client";
import { useState, useEffect } from "react";
import { useEffect } from "react";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchOpenRouterModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
ModelsField,
APIBaseField,
ModelSelectionField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const OPENROUTER_PROVIDER_NAME = "openrouter";
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
interface OpenRouterModalValues extends BaseLLMFormValues {
api_key: string;
api_base: string;
}
interface OpenRouterModalInternalsProps {
formikProps: FormikProps<OpenRouterModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function OpenRouterModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: OpenRouterModalInternalsProps) {
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const formikProps = useFormikContext<OpenRouterModalValues>();
const isFetchDisabled =
!formikProps.values.api_base || !formikProps.values.api_key;
@@ -76,12 +53,12 @@ function OpenRouterModalInternals({
const { models, error } = await fetchOpenRouterModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key,
provider_name: existingLlmProvider?.name,
provider_name: LLMProviderName.OPENROUTER,
});
if (error) {
throw new Error(error);
}
setFetchedModels(models);
formikProps.setFieldValue("model_configurations", models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -97,58 +74,34 @@ function OpenRouterModalInternals({
}, []);
return (
<LLMConfigurationModalWrapper
providerEndpoint={OPENROUTER_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
>
<InputTypeInField
name="api_base"
placeholder="Your OpenRouter base URL"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<>
<APIBaseField
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
placeholder="Your OpenRouter base URL"
/>
<APIKeyField providerName="OpenRouter" />
{!isOnboarding && (
<>
<FieldSeparator />
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. openai/gpt-4o" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</LLMConfigurationModalWrapper>
</>
);
}
@@ -158,109 +111,68 @@ export default function OpenRouterModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
OPENROUTER_PROVIDER_NAME
);
if (open === false) return null;
const initialValues: OpenRouterModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.OPENROUTER,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
} as OpenRouterModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
apiBase: true,
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: OpenRouterModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: OPENROUTER_PROVIDER_NAME,
provider: OPENROUTER_PROVIDER_NAME,
api_key: "",
api_base: DEFAULT_API_BASE,
default_model_name: "",
} as OpenRouterModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.OPENROUTER}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: OPENROUTER_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: OPENROUTER_PROVIDER_NAME,
values,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.OPENROUTER,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<OpenRouterModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
<OpenRouterModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
);
}

View File

@@ -1,38 +1,27 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { Formik } from "formik";
import { FileUploadFormField } from "@/components/Field";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import { LLMProviderFormProps } from "@/interfaces/llm";
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
ModelSelectionField,
DisplayNameField,
FieldSeparator,
FieldWrapper,
ModelsAccessField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
const VERTEXAI_DEFAULT_LOCATION = "global";
interface VertexAIModalValues extends BaseLLMFormValues {
@@ -48,87 +37,50 @@ export default function VertexAIModal({
shouldMarkAsDefault,
open,
onOpenChange,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
onSuccess,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
VERTEXAI_PROVIDER_NAME
);
if (open === false) return null;
const initialValues: VertexAIModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.VERTEX_AI,
existingLlmProvider
),
custom_config: {
vertex_credentials:
(existingLlmProvider?.custom_config?.vertex_credentials as string) ??
"",
vertex_location:
(existingLlmProvider?.custom_config?.vertex_location as string) ??
VERTEXAI_DEFAULT_LOCATION,
},
} as VertexAIModalValues;
const validationSchema = buildValidationSchema(isOnboarding, {
extra: {
custom_config: Yup.object({
vertex_credentials: Yup.string().required(
"Credentials file is required"
),
vertex_location: Yup.string(),
}),
},
});
const onClose = () => onOpenChange?.(false);
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: VertexAIModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: VERTEXAI_PROVIDER_NAME,
provider: VERTEXAI_PROVIDER_NAME,
default_model_name: VERTEXAI_DEFAULT_MODEL,
custom_config: {
vertex_credentials: "",
vertex_location: VERTEXAI_DEFAULT_LOCATION,
},
} as VertexAIModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
default_model_name:
(defaultModelName &&
modelConfigurations.some((m) => m.name === defaultModelName)
? defaultModelName
: undefined) ??
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
custom_config: {
vertex_credentials:
(existingLlmProvider?.custom_config
?.vertex_credentials as string) ?? "",
vertex_location:
(existingLlmProvider?.custom_config?.vertex_location as string) ??
VERTEXAI_DEFAULT_LOCATION,
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
default_model_name: Yup.string().required("Model name is required"),
custom_config: Yup.object({
vertex_credentials: Yup.string().required(
"Credentials file is required"
),
vertex_location: Yup.string(),
}),
})
: buildDefaultValidationSchema().shape({
custom_config: Yup.object({
vertex_credentials: Yup.string().required(
"Credentials file is required"
),
vertex_location: Yup.string(),
}),
});
if (open === false) return null;
return (
<Formik
<ModalWrapper
providerName={LLMProviderName.VERTEX_AI}
llmProvider={existingLlmProvider}
onClose={onClose}
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
onSubmit={async (values, { setSubmitting, setStatus }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(
([key, v]) => key === "vertex_credentials" || v !== ""
@@ -143,101 +95,75 @@ export default function VertexAIModal({
: undefined,
};
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: VERTEXAI_PROVIDER_NAME,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
is_auto_mode:
values.default_model_name === VERTEXAI_DEFAULT_MODEL,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: VERTEXAI_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.VERTEX_AI,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
}}
>
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={VERTEXAI_PROVIDER_NAME}
providerName={VERTEXAI_DISPLAY_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="custom_config.vertex_location"
title="Google Cloud Region Name"
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
>
<FieldWrapper>
<InputLayouts.Vertical
name="custom_config.vertex_location"
title="Google Cloud Region Name"
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
>
<InputTypeInField
name="custom_config.vertex_location"
placeholder={VERTEXAI_DEFAULT_LOCATION}
/>
</InputLayouts.Vertical>
</FieldWrapper>
<InputTypeInField
name="custom_config.vertex_location"
placeholder={VERTEXAI_DEFAULT_LOCATION}
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="custom_config.vertex_credentials"
title="API Key"
subDescription="Attach your API key JSON from Google Cloud to access your models."
>
<FileUploadFormField
name="custom_config.vertex_credentials"
label=""
/>
</InputLayouts.Vertical>
</FieldWrapper>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="custom_config.vertex_credentials"
title="API Key"
subDescription="Attach your API key JSON from Google Cloud to access your models."
>
<FileUploadFormField
name="custom_config.vertex_credentials"
label=""
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldSeparator />
{!isOnboarding && (
<DisplayNameField disabled={!!existingLlmProvider} />
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gemini-2.5-pro" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
)}
{!isOnboarding && <ModelsAccessField formikProps={formikProps} />}
</LLMConfigurationModalWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
</Formik>
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
);
}

View File

@@ -7,9 +7,10 @@ import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
return (
@@ -54,11 +55,13 @@ export function getModalForExistingProvider(
case LLMProviderName.OPENROUTER:
return <OpenRouterModal {...props} />;
case LLMProviderName.LM_STUDIO:
return <LMStudioForm {...props} />;
return <LMStudioModal {...props} />;
case LLMProviderName.LITELLM_PROXY:
return <LiteLLMProxyModal {...props} />;
case LLMProviderName.BIFROST:
return <BifrostModal {...props} />;
case LLMProviderName.OPENAI_COMPATIBLE:
return <OpenAICompatibleModal {...props} />;
default:
return <CustomModal {...props} />;
}

View File

@@ -1,30 +1,31 @@
"use client";
import { ReactNode } from "react";
import { Form, FormikProps } from "formik";
import React, { useEffect, useRef, useState } from "react";
import { Formik, Form, useFormikContext } from "formik";
import type { FormikConfig } from "formik";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useAgents } from "@/hooks/useAgents";
import { useUserGroups } from "@/lib/hooks";
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
import { LLMProviderView, ModelConfiguration } from "@/interfaces/llm";
import * as InputLayouts from "@/layouts/input-layouts";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import Switch from "@/refresh-components/inputs/Switch";
import Text from "@/refresh-components/texts/Text";
import { Button, LineItemButton, Tag } from "@opal/components";
import { Button, LineItemButton } from "@opal/components";
import { BaseLLMFormValues } from "@/sections/modals/llmConfig/utils";
import { WithoutStyles } from "@opal/types";
import Separator from "@/refresh-components/Separator";
import type { RichStr } from "@opal/types";
import { Section } from "@/layouts/general-layouts";
import { Disabled, Hoverable } from "@opal/core";
import { Content } from "@opal/layouts";
import {
SvgArrowExchange,
SvgOnyxOctagon,
SvgOrganization,
SvgPlusCircle,
SvgRefreshCw,
SvgSparkle,
SvgUserManage,
@@ -46,27 +47,14 @@ import {
getProviderProductName,
} from "@/lib/llmConfig/providers";
export function FieldSeparator() {
return <Separator noPadding className="px-2" />;
}
export type FieldWrapperProps = WithoutStyles<
React.HTMLAttributes<HTMLDivElement>
>;
export function FieldWrapper(props: FieldWrapperProps) {
return <div {...props} className="p-2 w-full" />;
}
// ─── DisplayNameField ────────────────────────────────────────────────────────
export interface DisplayNameFieldProps {
disabled?: boolean;
}
export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
return (
<FieldWrapper>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="name"
title="Display Name"
@@ -78,56 +66,68 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
variant={disabled ? "disabled" : undefined}
/>
</InputLayouts.Vertical>
</FieldWrapper>
</InputLayouts.FieldPadder>
);
}
// ─── APIKeyField ─────────────────────────────────────────────────────────────
export interface APIKeyFieldProps {
/** Formik field name. @default "api_key" */
name?: string;
optional?: boolean;
providerName?: string;
subDescription?: string | RichStr;
}
export function APIKeyField({
name = "api_key",
optional = false,
providerName,
subDescription,
}: APIKeyFieldProps) {
return (
<FieldWrapper>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="api_key"
name={name}
title="API Key"
subDescription={
providerName
? `Paste your API key from ${providerName} to access your models.`
: "Paste your API key to access your models."
subDescription
? subDescription
: providerName
? `Paste your API key from ${providerName} to access your models.`
: "Paste your API key to access your models."
}
optional={optional}
>
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
<PasswordInputTypeInField name={name} />
</InputLayouts.Vertical>
</FieldWrapper>
</InputLayouts.FieldPadder>
);
}
// ─── SingleDefaultModelField ─────────────────────────────────────────────────
// ─── APIBaseField ───────────────────────────────────────────────────────────
export interface SingleDefaultModelFieldProps {
export interface APIBaseFieldProps {
optional?: boolean;
subDescription?: string | RichStr;
placeholder?: string;
}
export function SingleDefaultModelField({
placeholder = "E.g. gpt-4o",
}: SingleDefaultModelFieldProps) {
export function APIBaseField({
optional = false,
subDescription,
placeholder = "https://",
}: APIBaseFieldProps) {
return (
<InputLayouts.Vertical
name="default_model_name"
title="Default Model"
description="The model to use by default for this provider unless otherwise specified."
>
<InputTypeInField name="default_model_name" placeholder={placeholder} />
</InputLayouts.Vertical>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription={subDescription}
optional={optional}
>
<InputTypeInField name="api_base" placeholder={placeholder} />
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
);
}
@@ -137,13 +137,8 @@ export function SingleDefaultModelField({
const GROUP_PREFIX = "group:";
const AGENT_PREFIX = "agent:";
interface ModelsAccessFieldProps<T> {
formikProps: FormikProps<T>;
}
export function ModelsAccessField<T extends BaseLLMFormValues>({
formikProps,
}: ModelsAccessFieldProps<T>) {
export function ModelAccessField() {
const formikProps = useFormikContext<BaseLLMFormValues>();
const { agents } = useAgents();
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const { data: usersData } = useUsers({ includeApiKeys: false });
@@ -229,7 +224,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
return (
<div className="flex flex-col w-full">
<FieldWrapper>
<InputLayouts.FieldPadder>
<InputLayouts.Horizontal
name="is_public"
title="Models Access"
@@ -250,7 +245,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
</InputSelect.Content>
</InputSelect>
</InputLayouts.Horizontal>
</FieldWrapper>
</InputLayouts.FieldPadder>
{!isPublic && (
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
@@ -316,7 +311,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
</div>
)}
<FieldSeparator />
<InputLayouts.FieldSeparator />
{selectedAgentIds.length > 0 ? (
<div className="grid grid-cols-2 gap-1 w-full">
@@ -371,84 +366,73 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
// ─── ModelsField ─────────────────────────────────────────────────────
export interface ModelsFieldProps<T> {
formikProps: FormikProps<T>;
modelConfigurations: ModelConfiguration[];
recommendedDefaultModel: SimpleKnownModel | null;
export interface ModelSelectionFieldProps {
shouldShowAutoUpdateToggle: boolean;
/** Called when the user clicks the refresh button to re-fetch models. */
onRefetch?: () => Promise<void> | void;
/** Called when the user adds a custom model name (e.g. for Azure). */
onAddModel?: (modelName: string) => void;
}
export function ModelsField<T extends BaseLLMFormValues>({
formikProps,
modelConfigurations,
recommendedDefaultModel,
export function ModelSelectionField({
shouldShowAutoUpdateToggle,
onRefetch,
}: ModelsFieldProps<T>) {
const isAutoMode = formikProps.values.is_auto_mode;
const selectedModels = formikProps.values.selected_model_names ?? [];
const defaultModel = formikProps.values.default_model_name;
onAddModel,
}: ModelSelectionFieldProps) {
const formikProps = useFormikContext<BaseLLMFormValues>();
const [newModelName, setNewModelName] = useState("");
// When the auto-update toggle is hidden, auto mode should have no effect —
// otherwise models can't be deselected and "Select All" stays disabled.
const isAutoMode =
shouldShowAutoUpdateToggle && formikProps.values.is_auto_mode;
const models = formikProps.values.model_configurations;
function handleCheckboxChange(modelName: string, checked: boolean) {
// Read current values inside the handler to avoid stale closure issues
const currentSelected = formikProps.values.selected_model_names ?? [];
const currentDefault = formikProps.values.default_model_name;
if (checked) {
const newSelected = [...currentSelected, modelName];
formikProps.setFieldValue("selected_model_names", newSelected);
// If this is the first model, set it as default
if (currentSelected.length === 0) {
formikProps.setFieldValue("default_model_name", modelName);
}
} else {
const newSelected = currentSelected.filter((name) => name !== modelName);
formikProps.setFieldValue("selected_model_names", newSelected);
// If removing the default, set the first remaining model as default
if (currentDefault === modelName && newSelected.length > 0) {
formikProps.setFieldValue("default_model_name", newSelected[0]);
} else if (newSelected.length === 0) {
formikProps.setFieldValue("default_model_name", undefined);
}
// Snapshot the original model visibility so we can restore it when
// toggling auto mode back on.
const originalModelsRef = useRef(models);
useEffect(() => {
if (originalModelsRef.current.length === 0 && models.length > 0) {
originalModelsRef.current = models;
}
}
}, [models]);
function handleSetDefault(modelName: string) {
formikProps.setFieldValue("default_model_name", modelName);
// Automatically derive test_model_name from model_configurations.
// Any change to visibility or the model list syncs this automatically.
useEffect(() => {
const firstVisible = models.find((m) => m.is_visible)?.name;
if (firstVisible !== formikProps.values.test_model_name) {
formikProps.setFieldValue("test_model_name", firstVisible);
}
}, [models]); // eslint-disable-line react-hooks/exhaustive-deps
function setVisibility(modelName: string, visible: boolean) {
const updated = models.map((m) =>
m.name === modelName ? { ...m, is_visible: visible } : m
);
formikProps.setFieldValue("model_configurations", updated);
}
function handleToggleAutoMode(nextIsAutoMode: boolean) {
formikProps.setFieldValue("is_auto_mode", nextIsAutoMode);
formikProps.setFieldValue(
"selected_model_names",
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
);
formikProps.setFieldValue(
"default_model_name",
recommendedDefaultModel?.name ?? undefined
);
}
const allSelected =
modelConfigurations.length > 0 &&
modelConfigurations.every((m) => selectedModels.includes(m.name));
function handleToggleSelectAll() {
if (allSelected) {
formikProps.setFieldValue("selected_model_names", []);
formikProps.setFieldValue("default_model_name", undefined);
} else {
const allNames = modelConfigurations.map((m) => m.name);
formikProps.setFieldValue("selected_model_names", allNames);
if (!formikProps.values.default_model_name && allNames.length > 0) {
formikProps.setFieldValue("default_model_name", allNames[0]);
}
if (nextIsAutoMode) {
formikProps.setFieldValue(
"model_configurations",
originalModelsRef.current
);
}
}
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
const allSelected = models.length > 0 && models.every((m) => m.is_visible);
function handleToggleSelectAll() {
const nextVisible = !allSelected;
const updated = models.map((m) => ({
...m,
is_visible: nextVisible,
}));
formikProps.setFieldValue("model_configurations", updated);
}
const visibleModels = models.filter((m) => m.is_visible);
return (
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
@@ -460,15 +444,14 @@ export function ModelsField<T extends BaseLLMFormValues>({
center
>
<Section flexDirection="row" gap={0}>
<Disabled disabled={isAutoMode || modelConfigurations.length === 0}>
<Button
prominence="tertiary"
size="md"
onClick={handleToggleSelectAll}
>
{allSelected ? "Unselect All" : "Select All"}
</Button>
</Disabled>
<Button
disabled={isAutoMode || models.length === 0}
prominence="tertiary"
size="md"
onClick={handleToggleSelectAll}
>
{allSelected ? "Unselect All" : "Select All"}
</Button>
{onRefetch && (
<Button
prominence="tertiary"
@@ -489,91 +472,75 @@ export function ModelsField<T extends BaseLLMFormValues>({
</Section>
</InputLayouts.Horizontal>
{modelConfigurations.length === 0 ? (
{models.length === 0 ? (
<EmptyMessageCard title="No models available." />
) : (
<Section gap={0.25}>
{isAutoMode
? // Auto mode: read-only display
visibleModels.map((model) => (
<Hoverable.Root
? visibleModels.map((model) => (
<LineItemButton
key={model.name}
group="LLMConfigurationButton"
widthVariant="full"
>
<LineItemButton
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state="selected"
icon={() => <Checkbox checked />}
title={model.display_name || model.name}
rightChildren={
model.name === defaultModel ? (
<Section>
<Tag title="Default Model" color="blue" />
</Section>
) : undefined
}
/>
</Hoverable.Root>
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state="selected"
icon={() => <Checkbox checked />}
title={model.display_name || model.name}
/>
))
: // Manual mode: checkbox selection
modelConfigurations.map((modelConfiguration) => {
const isSelected = selectedModels.includes(
modelConfiguration.name
);
const isDefault = defaultModel === modelConfiguration.name;
: models.map((model) => (
<LineItemButton
key={model.name}
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state={model.is_visible ? "selected" : "empty"}
icon={() => <Checkbox checked={model.is_visible} />}
title={model.name}
onClick={() => setVisibility(model.name, !model.is_visible)}
/>
))}
</Section>
)}
return (
<Hoverable.Root
key={modelConfiguration.name}
group="LLMConfigurationButton"
widthVariant="full"
>
<LineItemButton
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state={isSelected ? "selected" : "empty"}
icon={() => <Checkbox checked={isSelected} />}
title={modelConfiguration.name}
onClick={() =>
handleCheckboxChange(
modelConfiguration.name,
!isSelected
)
}
rightChildren={
isSelected ? (
isDefault ? (
<Section>
<Tag color="blue" title="Default Model" />
</Section>
) : (
<Hoverable.Item
group="LLMConfigurationButton"
variant="opacity-on-hover"
>
<Button
size="sm"
prominence="internal"
onClick={(e) => {
e.stopPropagation();
handleSetDefault(modelConfiguration.name);
}}
type="button"
>
Set as default
</Button>
</Hoverable.Item>
)
) : undefined
}
/>
</Hoverable.Root>
);
})}
{onAddModel && !isAutoMode && (
<Section flexDirection="row" gap={0.5}>
<div className="flex-1">
<InputTypeIn
placeholder="Enter model name"
value={newModelName}
onChange={(e) => setNewModelName(e.target.value)}
onKeyDown={(e) => {
if (e.key === "Enter" && newModelName.trim()) {
e.preventDefault();
const trimmed = newModelName.trim();
if (!models.some((m) => m.name === trimmed)) {
onAddModel(trimmed);
setNewModelName("");
}
}
}}
showClearButton={false}
/>
</div>
<Button
prominence="secondary"
icon={SvgPlusCircle}
type="button"
disabled={
!newModelName.trim() ||
models.some((m) => m.name === newModelName.trim())
}
onClick={() => {
const trimmed = newModelName.trim();
if (trimmed && !models.some((m) => m.name === trimmed)) {
onAddModel(trimmed);
setNewModelName("");
}
}}
>
Add Model
</Button>
</Section>
)}
@@ -593,41 +560,96 @@ export function ModelsField<T extends BaseLLMFormValues>({
);
}
// ============================================================================
// LLMConfigurationModalWrapper
// ============================================================================
// ─── ModalWrapper ─────────────────────────────────────────────────────
interface LLMConfigurationModalWrapperProps {
providerEndpoint: string;
providerName?: string;
existingProviderName?: string;
export interface ModalWrapperProps<
T extends BaseLLMFormValues = BaseLLMFormValues,
> {
providerName: string;
llmProvider?: LLMProviderView;
onClose: () => void;
isFormValid: boolean;
isDirty?: boolean;
isTesting?: boolean;
isSubmitting?: boolean;
children: ReactNode;
initialValues: T;
validationSchema: FormikConfig<T>["validationSchema"];
onSubmit: FormikConfig<T>["onSubmit"];
children: React.ReactNode;
}
export function ModalWrapper<T extends BaseLLMFormValues = BaseLLMFormValues>({
providerName,
llmProvider,
onClose,
initialValues,
validationSchema,
onSubmit,
children,
}: ModalWrapperProps<T>) {
return (
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount
onSubmit={onSubmit}
>
{() => (
<ModalWrapperInner
providerName={providerName}
llmProvider={llmProvider}
onClose={onClose}
modelConfigurations={initialValues.model_configurations}
>
{children}
</ModalWrapperInner>
)}
</Formik>
);
}
export function LLMConfigurationModalWrapper({
providerEndpoint,
interface ModalWrapperInnerProps {
providerName: string;
llmProvider?: LLMProviderView;
onClose: () => void;
modelConfigurations?: ModelConfiguration[];
children: React.ReactNode;
}
function ModalWrapperInner({
providerName,
existingProviderName,
llmProvider,
onClose,
isFormValid,
isDirty,
isTesting,
isSubmitting,
modelConfigurations,
children,
}: LLMConfigurationModalWrapperProps) {
const busy = isTesting || isSubmitting;
const providerIcon = getProviderIcon(providerEndpoint);
const providerDisplayName =
providerName ?? getProviderDisplayName(providerEndpoint);
const providerProductName = getProviderProductName(providerEndpoint);
}: ModalWrapperInnerProps) {
const { isValid, dirty, isSubmitting, status, setFieldValue, values } =
useFormikContext<BaseLLMFormValues>();
const title = existingProviderName
? `Configure "${existingProviderName}"`
// When SWR resolves after mount, populate model_configurations if still
// empty. test_model_name is then derived automatically by
// ModelSelectionField's useEffect.
useEffect(() => {
if (
modelConfigurations &&
modelConfigurations.length > 0 &&
values.model_configurations.length === 0
) {
setFieldValue("model_configurations", modelConfigurations);
}
}, [modelConfigurations]); // eslint-disable-line react-hooks/exhaustive-deps
const isTesting = status?.isTesting === true;
const busy = isTesting || isSubmitting;
const disabledTooltip = busy
? undefined
: !isValid
? "Please fill in all required fields."
: !dirty
? "No changes to save."
: undefined;
const providerIcon = getProviderIcon(providerName);
const providerDisplayName = getProviderDisplayName(providerName);
const providerProductName = getProviderProductName(providerName);
const title = llmProvider
? `Configure "${llmProvider.name}"`
: `Set up ${providerProductName}`;
const description = `Connect to ${providerDisplayName} and set up your ${providerProductName} models.`;
@@ -650,21 +672,20 @@ export function LLMConfigurationModalWrapper({
<Button prominence="secondary" onClick={onClose} type="button">
Cancel
</Button>
<Disabled
disabled={
!isFormValid || busy || (!!existingProviderName && !isDirty)
}
<Button
disabled={!isValid || !dirty || busy}
type="submit"
icon={busy ? SimpleLoader : undefined}
tooltip={disabledTooltip}
>
<Button type="submit" icon={busy ? SimpleLoader : undefined}>
{existingProviderName
? busy
? "Updating"
: "Update"
: busy
? "Connecting"
: "Connect"}
</Button>
</Disabled>
{llmProvider?.name
? busy
? "Updating"
: "Update"
: busy
? "Connecting"
: "Connect"}
</Button>
</Modal.Footer>
</Form>
</Modal.Content>

View File

@@ -1,13 +1,8 @@
import {
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
import {
LLM_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/lib/llmConfig/constants";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
import isEqual from "lodash/isEqual";
import { parseAzureTargetUri } from "@/lib/azureTargetUri";
@@ -18,13 +13,11 @@ import {
} from "@/lib/analytics";
import {
BaseLLMFormValues,
SubmitLLMProviderParams,
SubmitOnboardingProviderParams,
TestApiKeyResult,
filterModelConfigurations,
getAutoModeModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
// ─── Test helpers ─────────────────────────────────────────────────────────
const submitLlmTestRequest = async (
payload: Record<string, unknown>,
fallbackErrorMessage: string
@@ -50,161 +43,6 @@ const submitLlmTestRequest = async (
}
};
export const submitLLMProvider = async <T extends BaseLLMFormValues>({
providerName,
values,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
hideSuccess,
setIsTesting,
mutate,
onClose,
setSubmitting,
}: SubmitLLMProviderParams<T>): Promise<void> => {
setSubmitting(true);
const { selected_model_names: visibleModels, api_key, ...rest } = values;
// In auto mode, use recommended models from descriptor
// In manual mode, use user's selection
let filteredModelConfigurations: ModelConfiguration[];
let finalDefaultModelName = rest.default_model_name;
if (values.is_auto_mode) {
filteredModelConfigurations =
getAutoModeModelConfigurations(modelConfigurations);
// In auto mode, use the first recommended model as default if current default isn't in the list
const visibleModelNames = new Set(
filteredModelConfigurations.map((m) => m.name)
);
if (
finalDefaultModelName &&
!visibleModelNames.has(finalDefaultModelName)
) {
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
}
} else {
filteredModelConfigurations = filterModelConfigurations(
modelConfigurations,
visibleModels,
rest.default_model_name as string | undefined
);
}
const customConfigChanged = !isEqual(
values.custom_config,
initialValues.custom_config
);
const normalizedApiBase =
typeof rest.api_base === "string" && rest.api_base.trim() === ""
? undefined
: rest.api_base;
const finalValues = {
...rest,
api_base: normalizedApiBase,
default_model_name: finalDefaultModelName,
api_key,
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
custom_config_changed: customConfigChanged,
model_configurations: filteredModelConfigurations,
};
// Test the configuration
if (!isEqual(finalValues, initialValues)) {
setIsTesting(true);
const response = await fetch("/api/admin/llm/test", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider: providerName,
...finalValues,
model: finalDefaultModelName,
id: existingLlmProvider?.id,
}),
});
setIsTesting(false);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
toast.error(errorMsg);
setSubmitting(false);
return;
}
}
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}${
existingLlmProvider ? "" : "?is_creation=true"
}`,
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider: providerName,
...finalValues,
id: existingLlmProvider?.id,
}),
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
const fullErrorMsg = existingLlmProvider
? `Failed to update provider: ${errorMsg}`
: `Failed to enable provider: ${errorMsg}`;
toast.error(fullErrorMsg);
return;
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: finalDefaultModelName,
}),
});
if (!setDefaultResponse.ok) {
const errorMsg = (await setDefaultResponse.json()).detail;
toast.error(`Failed to set provider as default: ${errorMsg}`);
return;
}
}
await refreshLlmProviderCaches(mutate);
onClose();
if (!hideSuccess) {
const successMsg = existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!";
toast.success(successMsg);
}
const knownProviders = new Set<string>(Object.values(LLMProviderName));
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
provider: knownProviders.has(providerName) ? providerName : "custom",
is_creation: !existingLlmProvider,
source: LLMProviderConfiguredSource.ADMIN_PAGE,
});
setSubmitting(false);
};
export const testApiKeyHelper = async (
providerName: string,
formValues: Record<string, unknown>,
@@ -241,7 +79,7 @@ export const testApiKeyHelper = async (
...((formValues?.custom_config as Record<string, unknown>) ?? {}),
...(customConfigOverride ?? {}),
},
model: modelName ?? (formValues?.default_model_name as string) ?? "",
model: modelName ?? (formValues?.test_model_name as string) ?? "",
};
return await submitLlmTestRequest(
@@ -259,96 +97,148 @@ export const testCustomProvider = async (
);
};
export const submitOnboardingProvider = async ({
// ─── Submit provider ──────────────────────────────────────────────────────
export interface SubmitProviderParams<
T extends BaseLLMFormValues = BaseLLMFormValues,
> {
providerName: string;
values: T;
initialValues: T;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
isCustomProvider?: boolean;
setStatus: (status: Record<string, unknown>) => void;
setSubmitting: (submitting: boolean) => void;
onClose: () => void;
/** Called after successful create/update + set-default. Use for cache refresh, state updates, toasts, etc. */
onSuccess?: () => void | Promise<void>;
/** Analytics source for tracking. @default LLMProviderConfiguredSource.ADMIN_PAGE */
analyticsSource?: LLMProviderConfiguredSource;
}
export async function submitProvider<T extends BaseLLMFormValues>({
providerName,
payload,
onboardingState,
onboardingActions,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
isCustomProvider,
setStatus,
setSubmitting,
onClose,
setIsSubmitting,
}: SubmitOnboardingProviderParams): Promise<void> => {
setIsSubmitting(true);
onSuccess,
analyticsSource = LLMProviderConfiguredSource.ADMIN_PAGE,
}: SubmitProviderParams<T>): Promise<void> {
setSubmitting(true);
// Test credentials
let result: TestApiKeyResult;
if (isCustomProvider) {
result = await testCustomProvider(payload);
} else {
result = await testApiKeyHelper(providerName, payload);
const { test_model_name, api_key, ...rest } = values;
const testModelName =
test_model_name ||
values.model_configurations.find((m) => m.is_visible)?.name ||
"";
// ── Test credentials ────────────────────────────────────────────────
const customConfigChanged = !isEqual(
values.custom_config,
initialValues.custom_config
);
const normalizedApiBase =
typeof rest.api_base === "string" && rest.api_base.trim() === ""
? undefined
: rest.api_base;
const finalValues = {
...rest,
api_base: normalizedApiBase,
api_key,
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
custom_config_changed: customConfigChanged,
};
if (!isEqual(finalValues, initialValues)) {
setStatus({ isTesting: true });
const testResult = await submitLlmTestRequest(
{
provider: providerName,
...finalValues,
model: testModelName,
id: existingLlmProvider?.id,
},
"An error occurred while testing the provider."
);
setStatus({ isTesting: false });
if (!testResult.ok) {
toast.error(testResult.errorMessage);
setSubmitting(false);
return;
}
}
if (!result.ok) {
toast.error(result.errorMessage);
setIsSubmitting(false);
return;
}
// Create provider
const response = await fetch(`${LLM_PROVIDERS_ADMIN_URL}?is_creation=true`, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload),
});
// ── Create/update provider ──────────────────────────────────────────
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}${
existingLlmProvider ? "" : "?is_creation=true"
}`,
{
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: providerName,
...finalValues,
id: existingLlmProvider?.id,
}),
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
toast.error(errorMsg);
setIsSubmitting(false);
const fullErrorMsg = existingLlmProvider
? `Failed to update provider: ${errorMsg}`
: `Failed to enable provider: ${errorMsg}`;
toast.error(fullErrorMsg);
setSubmitting(false);
return;
}
// Set as default if first provider
if (
onboardingState?.data?.llmProviders == null ||
onboardingState.data.llmProviders.length === 0
) {
// ── Set as default ──────────────────────────────────────────────────
if (shouldMarkAsDefault && testModelName) {
try {
const newLlmProvider = await response.json();
if (newLlmProvider?.id != null) {
const defaultModelName =
(payload as Record<string, string>).default_model_name ??
(payload as Record<string, ModelConfiguration[]>)
.model_configurations?.[0]?.name ??
"";
if (defaultModelName) {
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: defaultModelName,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
toast.error(err?.detail ?? "Failed to set provider as default");
setIsSubmitting(false);
return;
}
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: testModelName,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
toast.error(err?.detail ?? "Failed to set provider as default");
setSubmitting(false);
return;
}
}
} catch (_e) {
} catch {
toast.error("Failed to set new provider as default");
}
}
// ── Post-success ────────────────────────────────────────────────────
const knownProviders = new Set<string>(Object.values(LLMProviderName));
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
provider: isCustomProvider ? "custom" : providerName,
is_creation: true,
source: LLMProviderConfiguredSource.CHAT_ONBOARDING,
provider: knownProviders.has(providerName) ? providerName : "custom",
is_creation: !existingLlmProvider,
source: analyticsSource,
});
// Update onboarding state
onboardingActions.updateData({
llmProviders: [
...(onboardingState?.data.llmProviders ?? []),
isCustomProvider ? "custom" : providerName,
],
});
onboardingActions.setButtonActive(true);
if (onSuccess) await onSuccess();
setIsSubmitting(false);
setSubmitting(false);
onClose();
};
}

View File

@@ -1,197 +1,130 @@
import {
LLMProviderName,
LLMProviderView,
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { ScopedMutator } from "swr";
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
// Common class names for the Form component across all LLM provider forms
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
// ─── useInitialValues ─────────────────────────────────────────────────────
export const buildDefaultInitialValues = (
/** Builds the merged model list from existing + well-known, deduped by name. */
function buildModelConfigurations(
existingLlmProvider?: LLMProviderView,
modelConfigurations?: ModelConfiguration[],
currentDefaultModelName?: string
) => {
const defaultModelName =
(currentDefaultModelName &&
existingLlmProvider?.model_configurations?.some(
(m) => m.name === currentDefaultModelName
)
? currentDefaultModelName
: undefined) ??
existingLlmProvider?.model_configurations?.[0]?.name ??
modelConfigurations?.[0]?.name ??
"";
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
): ModelConfiguration[] {
const existingModels = existingLlmProvider?.model_configurations ?? [];
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
// Auto mode must be explicitly enabled by the user
// Default to false for new providers, preserve existing value when editing
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;
const modelMap = new Map<string, ModelConfiguration>();
wellKnownModels.forEach((m) => modelMap.set(m.name, m));
existingModels.forEach((m) => modelMap.set(m.name, m));
return Array.from(modelMap.values());
}
/** Shared initial values for all LLM provider forms (both onboarding and admin). */
export function useInitialValues(
isOnboarding: boolean,
providerName: LLMProviderName,
existingLlmProvider?: LLMProviderView
) {
const { wellKnownLLMProvider } = useWellKnownLLMProvider(providerName);
const modelConfigurations = buildModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? undefined
);
const testModelName =
modelConfigurations.find((m) => m.is_visible)?.name ??
wellKnownLLMProvider?.recommended_default_model?.name;
return {
name: existingLlmProvider?.name || "",
default_model_name: defaultModelName,
provider: existingLlmProvider?.provider ?? providerName,
name: isOnboarding ? providerName : existingLlmProvider?.name ?? "",
api_key: existingLlmProvider?.api_key ?? undefined,
api_base: existingLlmProvider?.api_base ?? undefined,
is_public: existingLlmProvider?.is_public ?? true,
is_auto_mode: isAutoMode,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
groups: existingLlmProvider?.groups ?? [],
personas: existingLlmProvider?.personas ?? [],
selected_model_names: existingLlmProvider
? existingLlmProvider.model_configurations
.filter((modelConfiguration) => modelConfiguration.is_visible)
.map((modelConfiguration) => modelConfiguration.name)
: modelConfigurations
?.filter((modelConfiguration) => modelConfiguration.is_visible)
.map((modelConfiguration) => modelConfiguration.name) ?? [],
model_configurations: modelConfigurations,
test_model_name: testModelName,
};
};
}
// ─── buildValidationSchema ────────────────────────────────────────────────
interface ValidationSchemaOptions {
apiKey?: boolean;
apiBase?: boolean;
extra?: Yup.ObjectShape;
}
/**
* Builds the validation schema for a modal.
*
* @param isOnboarding — controls the base schema:
* - `true`: minimal (only `test_model_name`).
* - `false`: full admin schema (display name, access, models, etc.).
* @param options.apiKey — require `api_key`.
* @param options.apiBase — require `api_base`.
* @param options.extra — arbitrary Yup fields for provider-specific validation.
*/
export function buildValidationSchema(
isOnboarding: boolean,
{ apiKey, apiBase, extra }: ValidationSchemaOptions = {}
) {
const providerFields: Yup.ObjectShape = {
...(apiKey && {
api_key: Yup.string().required("API Key is required"),
}),
...(apiBase && {
api_base: Yup.string().required("API Base URL is required"),
}),
...extra,
};
if (isOnboarding) {
return Yup.object().shape({
test_model_name: Yup.string().required("Model name is required"),
...providerFields,
});
}
export const buildDefaultValidationSchema = () => {
return Yup.object({
name: Yup.string().required("Display Name is required"),
default_model_name: Yup.string().required("Model name is required"),
is_public: Yup.boolean().required(),
is_auto_mode: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
personas: Yup.array().of(Yup.number()),
selected_model_names: Yup.array().of(Yup.string()),
test_model_name: Yup.string().required("Model name is required"),
...providerFields,
});
};
}
export const buildAvailableModelConfigurations = (
existingLlmProvider?: LLMProviderView,
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
): ModelConfiguration[] => {
const existingModels = existingLlmProvider?.model_configurations ?? [];
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
// ─── Form value types ─────────────────────────────────────────────────────
// Create a map to deduplicate by model name, preferring existing models
const modelMap = new Map<string, ModelConfiguration>();
// Add well-known models first
wellKnownModels.forEach((model) => {
modelMap.set(model.name, model);
});
// Override with existing models (they take precedence)
existingModels.forEach((model) => {
modelMap.set(model.name, model);
});
return Array.from(modelMap.values());
};
// Base form values that all provider forms share
/** Base form values that all provider forms share. */
export interface BaseLLMFormValues {
name: string;
api_key?: string;
api_base?: string;
default_model_name?: string;
/** Model name used for the test request — automatically derived. */
test_model_name?: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
selected_model_names: string[];
/** The full model list with is_visible set directly by user interaction. */
model_configurations: ModelConfiguration[];
custom_config?: Record<string, string>;
}
export interface SubmitLLMProviderParams<
T extends BaseLLMFormValues = BaseLLMFormValues,
> {
providerName: string;
values: T;
initialValues: T;
modelConfigurations: ModelConfiguration[];
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
hideSuccess?: boolean;
setIsTesting: (testing: boolean) => void;
mutate: ScopedMutator;
onClose: () => void;
setSubmitting: (submitting: boolean) => void;
}
export const filterModelConfigurations = (
currentModelConfigurations: ModelConfiguration[],
visibleModels: string[],
defaultModelName?: string
): ModelConfiguration[] => {
return currentModelConfigurations
.map(
(modelConfiguration): ModelConfiguration => ({
name: modelConfiguration.name,
is_visible: visibleModels.includes(modelConfiguration.name),
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
supports_image_input: modelConfiguration.supports_image_input,
supports_reasoning: modelConfiguration.supports_reasoning,
display_name: modelConfiguration.display_name,
})
)
.filter(
(modelConfiguration) =>
modelConfiguration.name === defaultModelName ||
modelConfiguration.is_visible
);
};
// Helper to get model configurations for auto mode
// In auto mode, we include ALL models but preserve their visibility status
// Models in the auto config are visible, others are created but not visible
export const getAutoModeModelConfigurations = (
modelConfigurations: ModelConfiguration[]
): ModelConfiguration[] => {
return modelConfigurations.map(
(modelConfiguration): ModelConfiguration => ({
name: modelConfiguration.name,
is_visible: modelConfiguration.is_visible,
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
supports_image_input: modelConfiguration.supports_image_input,
supports_reasoning: modelConfiguration.supports_reasoning,
display_name: modelConfiguration.display_name,
})
);
};
// ─── Misc ─────────────────────────────────────────────────────────────────
export type TestApiKeyResult =
| { ok: true }
| { ok: false; errorMessage: string };
export const getModelOptions = (
fetchedModelConfigurations: Array<{ name: string }>
) => {
return fetchedModelConfigurations.map((model) => ({
label: model.name,
value: model.name,
}));
};
/** Initial values used by onboarding forms (flat shape, always creating new). */
export const buildOnboardingInitialValues = () => ({
name: "",
provider: "",
api_key: "",
api_base: "",
api_version: "",
default_model_name: "",
model_configurations: [] as ModelConfiguration[],
custom_config: {} as Record<string, string>,
api_key_changed: true,
groups: [] as number[],
is_public: true,
is_auto_mode: false,
personas: [] as number[],
selected_model_names: [] as string[],
deployment_name: "",
target_uri: "",
});
export interface SubmitOnboardingProviderParams {
providerName: string;
payload: Record<string, unknown>;
onboardingState: OnboardingState;
onboardingActions: OnboardingActions;
isCustomProvider: boolean;
onClose: () => void;
setIsSubmitting: (submitting: boolean) => void;
}

View File

@@ -2,6 +2,7 @@ import React from "react";
import {
WellKnownLLMProviderDescriptor,
LLMProviderName,
LLMProviderFormProps,
} from "@/interfaces/llm";
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
@@ -12,8 +13,9 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
// Display info for LLM provider cards - title is the product name, displayName is the company/platform
const PROVIDER_DISPLAY_INFO: Record<
@@ -47,6 +49,10 @@ const PROVIDER_DISPLAY_INFO: Record<
title: "LiteLLM Proxy",
displayName: "LiteLLM Proxy",
},
[LLMProviderName.OPENAI_COMPATIBLE]: {
title: "OpenAI Compatible",
displayName: "OpenAI Compatible",
},
};
export function getProviderDisplayInfo(providerName: string): {
@@ -78,12 +84,26 @@ export function getOnboardingForm({
open,
onOpenChange,
}: OnboardingFormProps): React.ReactNode {
const sharedProps = {
const providerName = isCustomProvider
? "custom"
: llmDescriptor?.name ?? "custom";
const sharedProps: LLMProviderFormProps = {
variant: "onboarding" as const,
onboardingState,
shouldMarkAsDefault:
(onboardingState?.data.llmProviders ?? []).length === 0,
onboardingActions,
open,
onOpenChange,
onSuccess: () => {
onboardingActions.updateData({
llmProviders: [
...(onboardingState?.data.llmProviders ?? []),
providerName,
],
});
onboardingActions.setButtonActive(true);
},
};
// Handle custom provider
@@ -91,38 +111,36 @@ export function getOnboardingForm({
return <CustomModal {...sharedProps} />;
}
const providerProps = {
...sharedProps,
llmDescriptor,
};
switch (llmDescriptor.name) {
case LLMProviderName.OPENAI:
return <OpenAIModal {...providerProps} />;
return <OpenAIModal {...sharedProps} />;
case LLMProviderName.ANTHROPIC:
return <AnthropicModal {...providerProps} />;
return <AnthropicModal {...sharedProps} />;
case LLMProviderName.OLLAMA_CHAT:
return <OllamaModal {...providerProps} />;
return <OllamaModal {...sharedProps} />;
case LLMProviderName.AZURE:
return <AzureModal {...providerProps} />;
return <AzureModal {...sharedProps} />;
case LLMProviderName.BEDROCK:
return <BedrockModal {...providerProps} />;
return <BedrockModal {...sharedProps} />;
case LLMProviderName.VERTEX_AI:
return <VertexAIModal {...providerProps} />;
return <VertexAIModal {...sharedProps} />;
case LLMProviderName.OPENROUTER:
return <OpenRouterModal {...providerProps} />;
return <OpenRouterModal {...sharedProps} />;
case LLMProviderName.LM_STUDIO:
return <LMStudioForm {...providerProps} />;
return <LMStudioModal {...sharedProps} />;
case LLMProviderName.LITELLM_PROXY:
return <LiteLLMProxyModal {...providerProps} />;
return <LiteLLMProxyModal {...sharedProps} />;
case LLMProviderName.OPENAI_COMPATIBLE:
return <OpenAICompatibleModal {...sharedProps} />;
default:
return <CustomModal {...sharedProps} />;