mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-21 17:36:44 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fa3deedb9 | ||
|
|
593ccbcc66 | ||
|
|
9910487f37 | ||
|
|
d158639844 | ||
|
|
6d2bd97412 | ||
|
|
3d48b6a63e | ||
|
|
2a7b7c9187 | ||
|
|
c348d1855d |
@@ -11,6 +11,8 @@ require a valid SCIM bearer token.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import struct
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -22,6 +24,7 @@ from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -59,9 +62,25 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
# Namespace prefix for the seat-allocation advisory lock. Hashed together
|
||||
# with the tenant ID so the lock is scoped per-tenant (unrelated tenants
|
||||
# never block each other) and cannot collide with unrelated advisory locks.
|
||||
_SEAT_LOCK_NAMESPACE = "onyx_scim_seat_lock"
|
||||
|
||||
|
||||
def _seat_lock_id_for_tenant(tenant_id: str) -> int:
|
||||
"""Derive a stable 64-bit signed int lock id for this tenant's seat lock."""
|
||||
digest = hashlib.sha256(f"{_SEAT_LOCK_NAMESPACE}:{tenant_id}".encode()).digest()
|
||||
# pg_advisory_xact_lock takes a signed 8-byte int; unpack as such.
|
||||
return struct.unpack("q", digest[:8])[0]
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -200,12 +219,37 @@ def _apply_exclusions(
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
"""Return an error message if seat limit is reached, else None.
|
||||
|
||||
Acquires a transaction-scoped advisory lock so that concurrent
|
||||
SCIM requests are serialized. IdPs like Okta send provisioning
|
||||
requests in parallel batches — without serialization the check is
|
||||
vulnerable to a TOCTOU race where N concurrent requests each see
|
||||
"seats available", all insert, and the tenant ends up over its
|
||||
seat limit.
|
||||
|
||||
The lock is held until the caller's next COMMIT or ROLLBACK, which
|
||||
means the seat count cannot change between the check here and the
|
||||
subsequent INSERT/UPDATE. Each call site in this module follows
|
||||
the pattern: _check_seat_availability → write → dal.commit()
|
||||
(which releases the lock for the next waiting request).
|
||||
"""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)
|
||||
if check_fn is None:
|
||||
return None
|
||||
|
||||
# Transaction-scoped advisory lock — released on dal.commit() / dal.rollback().
|
||||
# The lock id is derived from the tenant so unrelated tenants never block
|
||||
# each other, and from a namespace string so it cannot collide with
|
||||
# unrelated advisory locks elsewhere in the codebase.
|
||||
lock_id = _seat_lock_id_for_tenant(get_current_tenant_id())
|
||||
dal.session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(:lock_id)"),
|
||||
{"lock_id": lock_id},
|
||||
)
|
||||
|
||||
result = check_fn(dal.session, seats_needed=1)
|
||||
if not result.available:
|
||||
return result.error_message or "Seat limit reached"
|
||||
|
||||
@@ -816,6 +816,29 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Maximum embedded images allowed in a single file. PDFs (and other formats)
|
||||
# with thousands of embedded images can OOM the user-file-processing worker
|
||||
# because every image is decoded with PIL and then sent to the vision LLM.
|
||||
# Enforced both at upload time (rejects the file) and during extraction
|
||||
# (defense-in-depth: caps the number of images materialized).
|
||||
#
|
||||
# Clamped to >= 0; a negative env value would turn upload validation into
|
||||
# always-fail and extraction into always-stop, which is never desired. 0
|
||||
# disables image extraction entirely, which is a valid (if aggressive) setting.
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_FILE") or 500)
|
||||
)
|
||||
|
||||
# Maximum embedded images allowed across all files in a single upload batch.
|
||||
# Protects against the scenario where a user uploads many files that each
|
||||
# fall under MAX_EMBEDDED_IMAGES_PER_FILE but aggregate to enough work
|
||||
# (serial-ish celery fan-out plus per-image vision-LLM calls) to OOM the
|
||||
# worker under concurrency or run up surprise latency/cost. Also clamped
|
||||
# to >= 0.
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD = max(
|
||||
0, int(os.environ.get("MAX_EMBEDDED_IMAGES_PER_UPLOAD") or 1000)
|
||||
)
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -6,6 +7,14 @@ from pydantic import BaseModel
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DirectThreadFetch:
|
||||
"""Request to fetch a Slack thread directly by channel and timestamp."""
|
||||
|
||||
channel_id: str
|
||||
thread_ts: str
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
@@ -49,7 +50,6 @@ from onyx.server.federated.models import FederatedConnectorDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -58,7 +58,6 @@ HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
|
||||
@@ -421,6 +420,94 @@ class SlackQueryResult(BaseModel):
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def _fetch_thread_from_url(
|
||||
thread_fetch: DirectThreadFetch,
|
||||
access_token: str,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
"""Fetch a thread directly from a Slack URL via conversations.replies."""
|
||||
channel_id = thread_fetch.channel_id
|
||||
thread_ts = thread_fetch.thread_ts
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=thread_ts,
|
||||
)
|
||||
response.validate()
|
||||
messages: list[dict[str, Any]] = response.get("messages", [])
|
||||
except SlackApiError as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch thread from URL (channel={channel_id}, ts={thread_ts}): {e}"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
if not messages:
|
||||
logger.warning(
|
||||
f"No messages found for URL override (channel={channel_id}, ts={thread_ts})"
|
||||
)
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# Build thread text from all messages
|
||||
thread_text = _build_thread_text(messages, access_token, None, slack_client)
|
||||
|
||||
# Get channel name from metadata cache or API
|
||||
channel_name = "unknown"
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_name = channel_metadata_dict[channel_id].get("name", "unknown")
|
||||
else:
|
||||
try:
|
||||
ch_response = slack_client.conversations_info(channel=channel_id)
|
||||
ch_response.validate()
|
||||
channel_info: dict[str, Any] = ch_response.get("channel", {})
|
||||
channel_name = channel_info.get("name", "unknown")
|
||||
except SlackApiError:
|
||||
pass
|
||||
|
||||
# Build the SlackMessage
|
||||
parent_msg = messages[0]
|
||||
message_ts = parent_msg.get("ts", thread_ts)
|
||||
username = parent_msg.get("user", "unknown_user")
|
||||
parent_text = parent_msg.get("text", "")
|
||||
snippet = (
|
||||
parent_text[:50].rstrip() + "..." if len(parent_text) > 50 else parent_text
|
||||
).replace("\n", " ")
|
||||
|
||||
doc_time = datetime.fromtimestamp(float(message_ts))
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (365 * 24 * 60 * 60)
|
||||
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
|
||||
|
||||
permalink = (
|
||||
f"https://slack.com/archives/{channel_id}/p{message_ts.replace('.', '')}"
|
||||
)
|
||||
|
||||
slack_message = SlackMessage(
|
||||
document_id=f"{channel_id}_{message_ts}",
|
||||
channel_id=channel_id,
|
||||
message_id=message_ts,
|
||||
thread_id=None, # Prevent double-enrichment in thread context fetch
|
||||
link=permalink,
|
||||
metadata={
|
||||
"channel": channel_name,
|
||||
"time": doc_time.isoformat(),
|
||||
},
|
||||
timestamp=doc_time,
|
||||
recency_bias=recency_bias,
|
||||
semantic_identifier=f"{username} in #{channel_name}: {snippet}",
|
||||
text=thread_text,
|
||||
highlighted_texts=set(),
|
||||
slack_score=100000.0, # High priority — user explicitly asked for this thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"URL override: fetched thread from channel={channel_id}, ts={thread_ts}, {len(messages)} messages"
|
||||
)
|
||||
|
||||
return SlackQueryResult(messages=[slack_message], filtered_channels=[])
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
access_token: str,
|
||||
@@ -432,7 +519,6 @@ def query_slack(
|
||||
available_channels: list[str] | None = None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
|
||||
@@ -662,7 +748,6 @@ def _fetch_thread_context(
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
message_id = message.message_id
|
||||
|
||||
# If not a thread, return original text as success
|
||||
if thread_id is None:
|
||||
@@ -695,62 +780,37 @@ def _fetch_thread_context(
|
||||
if len(messages) <= 1:
|
||||
return ThreadContextResult.success(message.text)
|
||||
|
||||
# Build thread text from thread starter + context window around matched message
|
||||
thread_text = _build_thread_text(
|
||||
messages, message_id, thread_id, access_token, team_id, slack_client
|
||||
)
|
||||
# Build thread text from thread starter + all replies
|
||||
thread_text = _build_thread_text(messages, access_token, team_id, slack_client)
|
||||
return ThreadContextResult.success(thread_text)
|
||||
|
||||
|
||||
def _build_thread_text(
|
||||
messages: list[dict[str, Any]],
|
||||
message_id: str,
|
||||
thread_id: str,
|
||||
access_token: str,
|
||||
team_id: str | None,
|
||||
slack_client: WebClient,
|
||||
) -> str:
|
||||
"""Build the thread text from messages."""
|
||||
"""Build thread text including all replies.
|
||||
|
||||
Includes the thread parent message followed by all replies in order.
|
||||
"""
|
||||
msg_text = messages[0].get("text", "")
|
||||
msg_sender = messages[0].get("user", "")
|
||||
thread_text = f"<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# All messages after index 0 are replies
|
||||
replies = messages[1:]
|
||||
if not replies:
|
||||
return thread_text
|
||||
|
||||
logger.debug(f"Thread {messages[0].get('ts')}: {len(replies)} replies included")
|
||||
thread_text += "\n\nReplies:"
|
||||
if thread_id == message_id:
|
||||
message_id_idx = 0
|
||||
else:
|
||||
message_id_idx = next(
|
||||
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
|
||||
)
|
||||
if not message_id_idx:
|
||||
return thread_text
|
||||
|
||||
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
|
||||
|
||||
if start_idx > 1:
|
||||
thread_text += "\n..."
|
||||
|
||||
for i in range(start_idx, message_id_idx):
|
||||
msg_text = messages[i].get("text", "")
|
||||
msg_sender = messages[i].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
msg_text = messages[message_id_idx].get("text", "")
|
||||
msg_sender = messages[message_id_idx].get("user", "")
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Add following replies
|
||||
len_replies = 0
|
||||
for msg in messages[message_id_idx + 1 :]:
|
||||
for msg in replies:
|
||||
msg_text = msg.get("text", "")
|
||||
msg_sender = msg.get("user", "")
|
||||
reply = f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
thread_text += reply
|
||||
|
||||
len_replies += len(reply)
|
||||
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
|
||||
thread_text += "\n..."
|
||||
break
|
||||
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
|
||||
|
||||
# Replace user IDs with names using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
@@ -976,7 +1036,16 @@ def slack_retrieval(
|
||||
|
||||
# Query slack with entity filtering
|
||||
llm = get_default_llm()
|
||||
query_strings = build_slack_queries(query, llm, entities, available_channels)
|
||||
query_items = build_slack_queries(query, llm, entities, available_channels)
|
||||
|
||||
# Partition into direct thread fetches and search query strings
|
||||
direct_fetches: list[DirectThreadFetch] = []
|
||||
query_strings: list[str] = []
|
||||
for item in query_items:
|
||||
if isinstance(item, DirectThreadFetch):
|
||||
direct_fetches.append(item)
|
||||
else:
|
||||
query_strings.append(item)
|
||||
|
||||
# Determine filtering based on entities OR context (bot)
|
||||
include_dm = False
|
||||
@@ -993,8 +1062,16 @@ def slack_retrieval(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
# Build search tasks
|
||||
search_tasks = [
|
||||
# Build search tasks — direct thread fetches + keyword searches
|
||||
search_tasks: list[tuple] = [
|
||||
(
|
||||
_fetch_thread_from_url,
|
||||
(fetch, access_token, channel_metadata_dict),
|
||||
)
|
||||
for fetch in direct_fetches
|
||||
]
|
||||
|
||||
search_tasks.extend(
|
||||
(
|
||||
query_slack,
|
||||
(
|
||||
@@ -1010,7 +1087,7 @@ def slack_retrieval(
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
]
|
||||
)
|
||||
|
||||
# If include_dm is True AND we're not already searching all channels,
|
||||
# add additional searches without channel filters.
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -638,12 +639,38 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
return [query_text]
|
||||
|
||||
|
||||
SLACK_URL_PATTERN = re.compile(
|
||||
r"https?://[a-z0-9-]+\.slack\.com/archives/([A-Z0-9]+)/p(\d{16})"
|
||||
)
|
||||
|
||||
|
||||
def extract_slack_message_urls(
|
||||
query_text: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Extract Slack message URLs from query text.
|
||||
|
||||
Parses URLs like:
|
||||
https://onyx-company.slack.com/archives/C097NBWMY8Y/p1775491616524769
|
||||
|
||||
Returns list of (channel_id, thread_ts) tuples.
|
||||
The 16-digit timestamp is converted to Slack ts format (with dot).
|
||||
"""
|
||||
results = []
|
||||
for match in SLACK_URL_PATTERN.finditer(query_text):
|
||||
channel_id = match.group(1)
|
||||
raw_ts = match.group(2)
|
||||
# Convert p1775491616524769 -> 1775491616.524769
|
||||
thread_ts = f"{raw_ts[:10]}.{raw_ts[10:]}"
|
||||
results.append((channel_id, thread_ts))
|
||||
return results
|
||||
|
||||
|
||||
def build_slack_queries(
|
||||
query: ChunkIndexRequest,
|
||||
llm: LLM,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
) -> list[str | DirectThreadFetch]:
|
||||
"""Build Slack query strings with date filtering and query expansion."""
|
||||
default_search_days = 30
|
||||
if entities:
|
||||
@@ -668,6 +695,15 @@ def build_slack_queries(
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
|
||||
time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"
|
||||
|
||||
# Check for Slack message URLs — if found, add direct fetch requests
|
||||
url_fetches: list[DirectThreadFetch] = []
|
||||
slack_urls = extract_slack_message_urls(query.query)
|
||||
for channel_id, thread_ts in slack_urls:
|
||||
url_fetches.append(
|
||||
DirectThreadFetch(channel_id=channel_id, thread_ts=thread_ts)
|
||||
)
|
||||
logger.info(f"Detected Slack URL: channel={channel_id}, ts={thread_ts}")
|
||||
|
||||
# ALWAYS extract channel references from the query (not just for recency queries)
|
||||
channel_references = extract_channel_references_from_query(query.query)
|
||||
|
||||
@@ -684,7 +720,9 @@ def build_slack_queries(
|
||||
|
||||
# If valid channels detected, use ONLY those channels with NO keywords
|
||||
# Return query with ONLY time filter + channel filter (no keywords)
|
||||
return [build_channel_override_query(channel_references, time_filter)]
|
||||
return url_fetches + [
|
||||
build_channel_override_query(channel_references, time_filter)
|
||||
]
|
||||
except ValueError as e:
|
||||
# If validation fails, log the error and continue with normal flow
|
||||
logger.warning(f"Channel reference validation failed: {e}")
|
||||
@@ -702,7 +740,8 @@ def build_slack_queries(
|
||||
rephrased_queries = expand_query_with_llm(query.query, llm)
|
||||
|
||||
# Build final query strings with time filters
|
||||
return [
|
||||
search_queries = [
|
||||
rephrased_query.strip() + time_filter
|
||||
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
]
|
||||
return url_fetches + search_queries
|
||||
|
||||
@@ -21,6 +21,7 @@ import chardet
|
||||
import openpyxl
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -188,6 +189,56 @@ def read_text_file(
|
||||
return file_content_raw, metadata
|
||||
|
||||
|
||||
def count_pdf_embedded_images(file: IO[Any], cap: int) -> int:
|
||||
"""Return the number of embedded images in a PDF, short-circuiting at cap+1.
|
||||
|
||||
Used to reject PDFs whose image count would OOM the user-file-processing
|
||||
worker during indexing. Returns a value > cap as a sentinel once the count
|
||||
exceeds the cap, so callers do not iterate thousands of image objects just
|
||||
to report a number. Returns 0 if the PDF cannot be parsed.
|
||||
|
||||
Owner-password-only PDFs (permission restrictions but no open password) are
|
||||
counted normally — they decrypt with an empty string. Truly password-locked
|
||||
PDFs are skipped (return 0) since we can't inspect them; the caller should
|
||||
ensure the password-protected check runs first.
|
||||
|
||||
Always restores the file pointer to its original position before returning.
|
||||
"""
|
||||
from pypdf import PdfReader
|
||||
|
||||
try:
|
||||
start_pos = file.tell()
|
||||
except Exception:
|
||||
start_pos = None
|
||||
try:
|
||||
if start_pos is not None:
|
||||
file.seek(0)
|
||||
reader = PdfReader(file)
|
||||
if reader.is_encrypted:
|
||||
# Try empty password first (owner-password-only PDFs); give up if that fails.
|
||||
try:
|
||||
if reader.decrypt("") == 0:
|
||||
return 0
|
||||
except Exception:
|
||||
return 0
|
||||
count = 0
|
||||
for page in reader.pages:
|
||||
for _ in page.images:
|
||||
count += 1
|
||||
if count > cap:
|
||||
return count
|
||||
return count
|
||||
except Exception:
|
||||
logger.warning("Failed to count embedded images in PDF", exc_info=True)
|
||||
return 0
|
||||
finally:
|
||||
if start_pos is not None:
|
||||
try:
|
||||
file.seek(start_pos)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
|
||||
"""
|
||||
Extract text from a PDF. For embedded images, a more complex approach is needed.
|
||||
@@ -251,8 +302,27 @@ def read_pdf_file(
|
||||
)
|
||||
|
||||
if extract_images:
|
||||
image_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
images_processed = 0
|
||||
cap_reached = False
|
||||
for page_num, page in enumerate(pdf_reader.pages):
|
||||
if cap_reached:
|
||||
break
|
||||
for image_file_object in page.images:
|
||||
if images_processed >= image_cap:
|
||||
# Defense-in-depth backstop. Upload-time validation
|
||||
# should have rejected files exceeding the cap, but
|
||||
# we also break here so a single oversized file can
|
||||
# never pin a worker.
|
||||
logger.warning(
|
||||
"PDF embedded image cap reached (%d). "
|
||||
"Skipping remaining images on page %d and beyond.",
|
||||
image_cap,
|
||||
page_num + 1,
|
||||
)
|
||||
cap_reached = True
|
||||
break
|
||||
|
||||
image = Image.open(io.BytesIO(image_file_object.data))
|
||||
img_byte_arr = io.BytesIO()
|
||||
image.save(img_byte_arr, format=image.format)
|
||||
@@ -265,6 +335,7 @@ def read_pdf_file(
|
||||
image_callback(img_bytes, image_name)
|
||||
else:
|
||||
extracted_images.append((img_bytes, image_name))
|
||||
images_processed += 1
|
||||
|
||||
return text, metadata, extracted_images
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class LlmProviderNames(str, Enum):
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -46,6 +47,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
]
|
||||
|
||||
|
||||
@@ -64,6 +66,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -116,6 +119,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -175,6 +175,28 @@ def _strip_tool_content_from_messages(
|
||||
return result
|
||||
|
||||
|
||||
def _fix_tool_user_message_ordering(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Insert a synthetic assistant message between tool and user messages.
|
||||
|
||||
Some models (e.g. Mistral on Azure) require strict message ordering where
|
||||
a user message cannot immediately follow a tool message. This function
|
||||
inserts a minimal assistant message to bridge the gap.
|
||||
"""
|
||||
if len(messages) < 2:
|
||||
return messages
|
||||
|
||||
result: list[dict[str, Any]] = [messages[0]]
|
||||
for msg in messages[1:]:
|
||||
prev_role = result[-1].get("role")
|
||||
curr_role = msg.get("role")
|
||||
if prev_role == "tool" and curr_role == "user":
|
||||
result.append({"role": "assistant", "content": "Noted. Continuing."})
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Check if any messages contain tool-related content blocks."""
|
||||
for msg in messages:
|
||||
@@ -305,12 +327,19 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
|
||||
# model names directly to the endpoint. We route through LiteLLM's
|
||||
# openai provider with the server's base URL, and ensure /v1 is appended.
|
||||
if model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
):
|
||||
self._custom_llm_provider = "openai"
|
||||
# LiteLLM's OpenAI client requires an api_key to be set.
|
||||
# Many OpenAI-compatible servers don't need auth, so supply a
|
||||
# placeholder to prevent LiteLLM from raising AuthenticationError.
|
||||
if not self._api_key:
|
||||
model_kwargs.setdefault("api_key", "not-needed")
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
@@ -427,17 +456,20 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
is_openai_compatible_proxy = self._model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
)
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
if is_openai_compatible_proxy:
|
||||
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
|
||||
# servers) expect model names sent directly to their endpoint.
|
||||
# We use custom_llm_provider="openai" so LiteLLM doesn't try
|
||||
# to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
@@ -528,7 +560,10 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
if (
|
||||
not (is_claude_model or is_ollama or is_mistral)
|
||||
or is_openai_compatible_proxy
|
||||
):
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included unless the request is
|
||||
@@ -576,6 +611,18 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Some models (e.g. Mistral) reject a user message
|
||||
# immediately after a tool message. Insert a synthetic
|
||||
# assistant bridge message to satisfy the ordering
|
||||
# constraint. Check both the provider and the deployment/
|
||||
# model name to catch Mistral hosted on Azure.
|
||||
model_or_deployment = (
|
||||
self._deployment_name or self._model_version or ""
|
||||
).lower()
|
||||
is_mistral_model = is_mistral or "mistral" in model_or_deployment
|
||||
if is_mistral_model:
|
||||
messages = _fix_tool_user_message_ordering(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
|
||||
@@ -15,6 +15,8 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
@@ -51,6 +52,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
|
||||
}
|
||||
|
||||
|
||||
@@ -336,6 +338,7 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -40,6 +40,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.client import app as celery_app
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -50,6 +52,9 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD
|
||||
from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES
|
||||
@@ -127,6 +132,49 @@ class DeleteFileResponse(BaseModel):
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _looks_like_pdf(filename: str, content_type: str | None) -> bool:
|
||||
"""True if either the filename or the content-type indicates a PDF.
|
||||
|
||||
Client-supplied ``content_type`` can be spoofed (e.g. a PDF uploaded with
|
||||
``Content-Type: application/octet-stream``), so we also fall back to
|
||||
extension-based detection via ``mimetypes.guess_type`` on the filename.
|
||||
"""
|
||||
if content_type == "application/pdf":
|
||||
return True
|
||||
guessed, _ = mimetypes.guess_type(filename)
|
||||
return guessed == "application/pdf"
|
||||
|
||||
|
||||
def _check_pdf_image_caps(
|
||||
filename: str, content: bytes, content_type: str | None, batch_total: int
|
||||
) -> int:
|
||||
"""Enforce per-file and per-batch embedded-image caps for PDFs.
|
||||
|
||||
Returns the number of embedded images in this file (0 for non-PDFs) so
|
||||
callers can update their running batch total. Raises OnyxError(INVALID_INPUT)
|
||||
if either cap is exceeded.
|
||||
"""
|
||||
if not _looks_like_pdf(filename, content_type):
|
||||
return 0
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Short-circuit at the larger cap so we get a useful count for both checks.
|
||||
count = count_pdf_embedded_images(BytesIO(content), max(file_cap, batch_cap))
|
||||
if count > file_cap:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"PDF '{filename}' contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting the document into smaller files.",
|
||||
)
|
||||
if batch_total + count > batch_cap:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
f"Upload would exceed the {batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading fewer image-heavy files at once.",
|
||||
)
|
||||
return count
|
||||
|
||||
|
||||
def _sanitize_path(path: str) -> str:
|
||||
"""Sanitize a file path, removing traversal attempts and normalizing.
|
||||
|
||||
@@ -355,6 +403,7 @@ async def upload_files(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Sanitize the base path
|
||||
@@ -374,6 +423,14 @@ async def upload_files(
|
||||
detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB",
|
||||
)
|
||||
|
||||
# Reject PDFs with an unreasonable per-file or per-batch image count
|
||||
batch_image_total += _check_pdf_image_caps(
|
||||
filename=file.filename or "unnamed",
|
||||
content=content,
|
||||
content_type=file.content_type,
|
||||
batch_total=batch_image_total,
|
||||
)
|
||||
|
||||
# Validate cumulative storage (existing + this upload batch)
|
||||
total_size += file_size
|
||||
if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES:
|
||||
@@ -472,6 +529,7 @@ async def upload_zip(
|
||||
|
||||
uploaded_entries: list[LibraryEntryResponse] = []
|
||||
total_size = 0
|
||||
batch_image_total = 0
|
||||
|
||||
# Extract zip contents into a subfolder named after the zip file
|
||||
zip_name = api_sanitize_filename(file.filename or "upload")
|
||||
@@ -510,6 +568,36 @@ async def upload_zip(
|
||||
logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size")
|
||||
continue
|
||||
|
||||
# Skip PDFs that would trip the per-file or per-batch image
|
||||
# cap (would OOM the user-file-processing worker). Matches
|
||||
# /upload behavior but uses skip-and-warn to stay consistent
|
||||
# with the zip path's handling of oversized files.
|
||||
zip_file_name = zip_info.filename.split("/")[-1]
|
||||
zip_content_type, _ = mimetypes.guess_type(zip_file_name)
|
||||
if zip_content_type == "application/pdf":
|
||||
image_count = count_pdf_embedded_images(
|
||||
BytesIO(file_content),
|
||||
max(
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
),
|
||||
)
|
||||
if image_count > MAX_EMBEDDED_IMAGES_PER_FILE:
|
||||
logger.warning(
|
||||
"Skipping '%s' - exceeds %d per-file embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE,
|
||||
)
|
||||
continue
|
||||
if batch_image_total + image_count > MAX_EMBEDDED_IMAGES_PER_UPLOAD:
|
||||
logger.warning(
|
||||
"Skipping '%s' - would exceed %d per-batch embedded-image cap",
|
||||
zip_info.filename,
|
||||
MAX_EMBEDDED_IMAGES_PER_UPLOAD,
|
||||
)
|
||||
continue
|
||||
batch_image_total += image_count
|
||||
|
||||
total_size += file_size
|
||||
|
||||
# Validate cumulative storage
|
||||
|
||||
@@ -10,9 +10,12 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
from onyx.configs.app_configs import MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
@@ -192,6 +195,11 @@ def categorize_uploaded_files(
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Failed to get current tenant ID: {str(e)}")
|
||||
|
||||
# Running total of embedded images across PDFs in this batch. Once the
|
||||
# aggregate cap is reached, subsequent PDFs in the same upload are
|
||||
# rejected even if they'd individually fit under MAX_EMBEDDED_IMAGES_PER_FILE.
|
||||
batch_image_total = 0
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
@@ -253,6 +261,47 @@ def categorize_uploaded_files(
|
||||
)
|
||||
continue
|
||||
|
||||
# Reject PDFs with an unreasonable number of embedded images
|
||||
# (either per-file or accumulated across this upload batch).
|
||||
# A PDF with thousands of embedded images can OOM the
|
||||
# user-file-processing celery worker because every image is
|
||||
# decoded with PIL and then sent to the vision LLM.
|
||||
if extension == ".pdf":
|
||||
file_cap = MAX_EMBEDDED_IMAGES_PER_FILE
|
||||
batch_cap = MAX_EMBEDDED_IMAGES_PER_UPLOAD
|
||||
# Use the larger of the two caps as the short-circuit
|
||||
# threshold so we get a useful count for both checks.
|
||||
# count_pdf_embedded_images restores the stream position.
|
||||
count = count_pdf_embedded_images(
|
||||
upload.file, max(file_cap, batch_cap)
|
||||
)
|
||||
if count > file_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"PDF contains too many embedded images "
|
||||
f"(more than {file_cap}). Try splitting "
|
||||
f"the document into smaller files."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
if batch_image_total + count > batch_cap:
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=(
|
||||
f"Upload would exceed the "
|
||||
f"{batch_cap}-image limit across all "
|
||||
f"files in this batch. Try uploading "
|
||||
f"fewer image-heavy files at once."
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
batch_image_total += count
|
||||
|
||||
text_content = extract_file_text(
|
||||
file=upload.file,
|
||||
file_name=filename,
|
||||
|
||||
@@ -74,6 +74,8 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
@@ -1575,3 +1577,95 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/openai-compatible/available-models")
|
||||
def get_openai_compatible_server_available_models(
|
||||
request: OpenAICompatibleModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenAICompatibleFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
OpenAICompatibleFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse OpenAI-compatible model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenAI Compatible",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openai_compatible_server_response(
|
||||
api_base: str, api_key: str | None = None
|
||||
) -> dict:
|
||||
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="OpenAI Compatible",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -464,3 +464,18 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
# OpenAI Compatible dynamic models fetch
|
||||
class OpenAICompatibleModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class OpenAICompatibleFinalModelResponse(BaseModel):
|
||||
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
|
||||
display_name: str # Human-readable name from API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -26,6 +26,7 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Tests for _build_thread_text function."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.slack_search import _build_thread_text
|
||||
|
||||
|
||||
def _make_msg(user: str, text: str, ts: str) -> dict[str, str]:
|
||||
return {"user": user, "text": text, "ts": ts}
|
||||
|
||||
|
||||
class TestBuildThreadText:
|
||||
"""Verify _build_thread_text includes full thread replies up to cap."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_includes_all_replies(self, mock_profiles: MagicMock) -> None:
|
||||
"""All replies within cap are included in output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "parent msg", "1000.0"),
|
||||
_make_msg("U2", "reply 1", "1001.0"),
|
||||
_make_msg("U3", "reply 2", "1002.0"),
|
||||
_make_msg("U4", "reply 3", "1003.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "parent msg" in result
|
||||
assert "reply 1" in result
|
||||
assert "reply 2" in result
|
||||
assert "reply 3" in result
|
||||
assert "..." not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_non_thread_returns_parent_only(self, mock_profiles: MagicMock) -> None:
|
||||
"""Single message (no replies) returns just the parent text."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [_make_msg("U1", "just a message", "1000.0")]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "just a message" in result
|
||||
assert "Replies:" not in result
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_parent_always_first(self, mock_profiles: MagicMock) -> None:
|
||||
"""Thread parent message is always the first line of output."""
|
||||
mock_profiles.return_value = {}
|
||||
messages = [
|
||||
_make_msg("U1", "I am the parent", "1000.0"),
|
||||
_make_msg("U2", "I am a reply", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
parent_pos = result.index("I am the parent")
|
||||
reply_pos = result.index("I am a reply")
|
||||
assert parent_pos < reply_pos
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.batch_get_user_profiles")
|
||||
def test_user_profiles_resolved(self, mock_profiles: MagicMock) -> None:
|
||||
"""User IDs in thread text are replaced with display names."""
|
||||
mock_profiles.return_value = {"U1": "Alice", "U2": "Bob"}
|
||||
messages = [
|
||||
_make_msg("U1", "hello", "1000.0"),
|
||||
_make_msg("U2", "world", "1001.0"),
|
||||
]
|
||||
result = _build_thread_text(messages, "token", "T123", MagicMock())
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "<@U1>" not in result
|
||||
assert "<@U2>" not in result
|
||||
@@ -0,0 +1,108 @@
|
||||
"""Tests for Slack URL parsing and direct thread fetch via URL override."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.context.search.federated.models import DirectThreadFetch
|
||||
from onyx.context.search.federated.slack_search import _fetch_thread_from_url
|
||||
from onyx.context.search.federated.slack_search_utils import extract_slack_message_urls
|
||||
|
||||
|
||||
class TestExtractSlackMessageUrls:
|
||||
"""Verify URL parsing extracts channel_id and timestamp correctly."""
|
||||
|
||||
def test_standard_url(self) -> None:
|
||||
query = "summarize https://mycompany.slack.com/archives/C097NBWMY8Y/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 1
|
||||
assert results[0] == ("C097NBWMY8Y", "1775491616.524769")
|
||||
|
||||
def test_multiple_urls(self) -> None:
|
||||
query = (
|
||||
"compare https://co.slack.com/archives/C111/p1234567890123456 "
|
||||
"and https://co.slack.com/archives/C222/p9876543210987654"
|
||||
)
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 2
|
||||
assert results[0] == ("C111", "1234567890.123456")
|
||||
assert results[1] == ("C222", "9876543210.987654")
|
||||
|
||||
def test_no_urls(self) -> None:
|
||||
query = "what happened in #general last week?"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_non_slack_url_ignored(self) -> None:
|
||||
query = "check https://google.com/archives/C111/p1234567890123456"
|
||||
results = extract_slack_message_urls(query)
|
||||
assert len(results) == 0
|
||||
|
||||
def test_timestamp_conversion(self) -> None:
|
||||
"""p prefix removed, dot inserted after 10th digit."""
|
||||
query = "https://x.slack.com/archives/CABC123/p1775491616524769"
|
||||
results = extract_slack_message_urls(query)
|
||||
channel_id, ts = results[0]
|
||||
assert channel_id == "CABC123"
|
||||
assert ts == "1775491616.524769"
|
||||
assert not ts.startswith("p")
|
||||
assert "." in ts
|
||||
|
||||
|
||||
class TestFetchThreadFromUrl:
|
||||
"""Verify _fetch_thread_from_url calls conversations.replies and returns SlackMessage."""
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search._build_thread_text")
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_successful_fetch(
|
||||
self, mock_webclient_cls: MagicMock, mock_build_thread: MagicMock
|
||||
) -> None:
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
|
||||
# Mock conversations_replies
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = [
|
||||
{"user": "U1", "text": "parent", "ts": "1775491616.524769"},
|
||||
{"user": "U2", "text": "reply 1", "ts": "1775491617.000000"},
|
||||
{"user": "U3", "text": "reply 2", "ts": "1775491618.000000"},
|
||||
]
|
||||
mock_client.conversations_replies.return_value = mock_response
|
||||
|
||||
# Mock channel info
|
||||
mock_ch_response = MagicMock()
|
||||
mock_ch_response.get.return_value = {"name": "general"}
|
||||
mock_client.conversations_info.return_value = mock_ch_response
|
||||
|
||||
mock_build_thread.return_value = (
|
||||
"U1: parent\n\nReplies:\n\nU2: reply 1\n\nU3: reply 2"
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(
|
||||
channel_id="C097NBWMY8Y", thread_ts="1775491616.524769"
|
||||
)
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
|
||||
assert len(result.messages) == 1
|
||||
msg = result.messages[0]
|
||||
assert msg.channel_id == "C097NBWMY8Y"
|
||||
assert msg.thread_id is None # Prevents double-enrichment
|
||||
assert msg.slack_score == 100000.0
|
||||
assert "parent" in msg.text
|
||||
mock_client.conversations_replies.assert_called_once_with(
|
||||
channel="C097NBWMY8Y", ts="1775491616.524769"
|
||||
)
|
||||
|
||||
@patch("onyx.context.search.federated.slack_search.WebClient")
|
||||
def test_api_error_returns_empty(self, mock_webclient_cls: MagicMock) -> None:
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_webclient_cls.return_value = mock_client
|
||||
mock_client.conversations_replies.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response=MagicMock(status_code=404),
|
||||
)
|
||||
|
||||
fetch = DirectThreadFetch(channel_id="CBAD", thread_ts="1234567890.123456")
|
||||
result = _fetch_thread_from_url(fetch, "xoxp-token")
|
||||
assert len(result.messages) == 0
|
||||
@@ -12,6 +12,10 @@ dependency on pypdf internals (pypdf.generic).
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_processing import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import count_pdf_embedded_images
|
||||
from onyx.file_processing.extract_file_text import pdf_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.password_validation import is_pdf_protected
|
||||
@@ -96,6 +100,80 @@ class TestReadPdfFile:
|
||||
# Returned list is empty when callback is used
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_skips_images_above_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the embedded-image cap is exceeded, remaining images are skipped.
|
||||
|
||||
The cap protects the user-file-processing worker from OOMing on PDFs
|
||||
with thousands of embedded images. Setting the cap to 0 should yield
|
||||
zero extracted images even though the fixture has one.
|
||||
"""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert images == []
|
||||
|
||||
def test_image_cap_at_limit_extracts_up_to_cap(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A cap >= image count behaves identically to the uncapped path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 100)
|
||||
_, _, images = read_pdf_file(_load("with_image.pdf"), extract_images=True)
|
||||
assert len(images) == 1
|
||||
|
||||
def test_image_cap_with_callback_stops_streaming_at_limit(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""The cap also short-circuits the streaming callback path."""
|
||||
monkeypatch.setattr(extract_file_text, "MAX_EMBEDDED_IMAGES_PER_FILE", 0)
|
||||
collected: list[tuple[bytes, str]] = []
|
||||
|
||||
def callback(data: bytes, name: str) -> None:
|
||||
collected.append((data, name))
|
||||
|
||||
read_pdf_file(
|
||||
_load("with_image.pdf"), extract_images=True, image_callback=callback
|
||||
)
|
||||
assert collected == []
|
||||
|
||||
|
||||
# ── count_pdf_embedded_images ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCountPdfEmbeddedImages:
|
||||
def test_returns_count_for_normal_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=10) == 1
|
||||
|
||||
def test_short_circuits_above_cap(self) -> None:
|
||||
# with_image.pdf has 1 image. cap=0 means "anything > 0 is over cap" —
|
||||
# function returns on first increment as the over-cap sentinel.
|
||||
assert count_pdf_embedded_images(_load("with_image.pdf"), cap=0) == 1
|
||||
|
||||
def test_returns_zero_for_pdf_without_images(self) -> None:
|
||||
assert count_pdf_embedded_images(_load("simple.pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_invalid_pdf(self) -> None:
|
||||
assert count_pdf_embedded_images(BytesIO(b"not a pdf"), cap=10) == 0
|
||||
|
||||
def test_returns_zero_for_password_locked_pdf(self) -> None:
|
||||
# encrypted.pdf has an open password; we can't inspect without it, so
|
||||
# the helper returns 0 — callers rely on the password-protected check
|
||||
# that runs earlier in the upload pipeline.
|
||||
assert count_pdf_embedded_images(_load("encrypted.pdf"), cap=10) == 0
|
||||
|
||||
def test_inspects_owner_password_only_pdf(self) -> None:
|
||||
# owner_protected.pdf is encrypted but has no open password. It should
|
||||
# decrypt with an empty string and count images normally. The fixture
|
||||
# has zero images, so 0 is a real count (not the "bail on encrypted"
|
||||
# path).
|
||||
assert count_pdf_embedded_images(_load("owner_protected.pdf"), cap=10) == 0
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("with_image.pdf")
|
||||
pdf.seek(42)
|
||||
count_pdf_embedded_images(pdf, cap=10)
|
||||
assert pdf.tell() == 42
|
||||
|
||||
|
||||
# ── pdf_to_text ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
@@ -9,7 +10,9 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _check_seat_availability
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import _seat_lock_id_for_tenant
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -741,3 +744,80 @@ class TestEmailCasePreservation:
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
|
||||
class TestSeatLock:
|
||||
"""Tests for the advisory lock in _check_seat_availability."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_abc")
|
||||
def test_acquires_advisory_lock_before_checking(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The advisory lock must be acquired before the seat check runs."""
|
||||
call_order: list[str] = []
|
||||
|
||||
def track_execute(stmt: Any, _params: Any = None) -> None:
|
||||
if "pg_advisory_xact_lock" in str(stmt):
|
||||
call_order.append("lock")
|
||||
|
||||
mock_dal.session.execute.side_effect = track_execute
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop"
|
||||
) as mock_fetch:
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_fn = MagicMock(return_value=mock_result)
|
||||
mock_fetch.return_value = mock_fn
|
||||
|
||||
def track_check(*_args: Any, **_kwargs: Any) -> Any:
|
||||
call_order.append("check")
|
||||
return mock_result
|
||||
|
||||
mock_fn.side_effect = track_check
|
||||
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
assert call_order == ["lock", "check"]
|
||||
|
||||
@patch("ee.onyx.server.scim.api.get_current_tenant_id", return_value="tenant_xyz")
|
||||
def test_lock_uses_tenant_scoped_key(
|
||||
self,
|
||||
_mock_tenant: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""The lock id must be derived from the tenant via _seat_lock_id_for_tenant."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.available = True
|
||||
mock_check = MagicMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=mock_check,
|
||||
):
|
||||
_check_seat_availability(mock_dal)
|
||||
|
||||
mock_dal.session.execute.assert_called_once()
|
||||
params = mock_dal.session.execute.call_args[0][1]
|
||||
assert params["lock_id"] == _seat_lock_id_for_tenant("tenant_xyz")
|
||||
|
||||
def test_seat_lock_id_is_stable_and_tenant_scoped(self) -> None:
|
||||
"""Lock id must be deterministic and differ across tenants."""
|
||||
assert _seat_lock_id_for_tenant("t1") == _seat_lock_id_for_tenant("t1")
|
||||
assert _seat_lock_id_for_tenant("t1") != _seat_lock_id_for_tenant("t2")
|
||||
|
||||
def test_no_lock_when_ee_absent(
|
||||
self,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""No advisory lock should be acquired when the EE check is absent."""
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api.fetch_ee_implementation_or_noop",
|
||||
return_value=None,
|
||||
):
|
||||
result = _check_seat_availability(mock_dal)
|
||||
|
||||
assert result is None
|
||||
mock_dal.session.execute.assert_not_called()
|
||||
|
||||
@@ -1,7 +1 @@
|
||||
"use client";
|
||||
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
}
|
||||
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
|
||||
|
||||
@@ -32,8 +32,10 @@ import {
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
OpenAICompatibleFetchParams,
|
||||
OpenAICompatibleModelResponse,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -44,6 +46,7 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"openai_compatible",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -82,6 +85,7 @@ export const getProviderIcon = (
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
openai_compatible: SvgPlug,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -411,6 +415,64 @@ export const fetchBifrostModels = async (
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches models from a generic OpenAI-compatible server.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchOpenAICompatibleModels = async (
|
||||
params: OpenAICompatibleFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/admin/llm/openai-compatible/available-models",
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: OpenAICompatibleModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches LiteLLM Proxy models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
@@ -531,6 +593,13 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return fetchOpenAICompatibleModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -545,6 +614,7 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -24,7 +24,6 @@ import {
|
||||
} from "@/app/craft/onboarding/constants";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { buildOnboardingInitialValues as buildInitialValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { testApiKeyHelper } from "@/sections/modals/llmConfig/svc";
|
||||
import OnboardingInfoPages from "@/app/craft/onboarding/components/OnboardingInfoPages";
|
||||
import OnboardingUserInfo from "@/app/craft/onboarding/components/OnboardingUserInfo";
|
||||
@@ -221,10 +220,8 @@ export default function BuildOnboardingModal({
|
||||
setConnectionStatus("testing");
|
||||
setErrorMessage("");
|
||||
|
||||
const baseValues = buildInitialValues();
|
||||
const providerName = `build-mode-${currentProviderConfig.providerName}`;
|
||||
const payload = {
|
||||
...baseValues,
|
||||
name: providerName,
|
||||
provider: currentProviderConfig.providerName,
|
||||
api_key: apiKey,
|
||||
|
||||
@@ -272,6 +272,22 @@ export default function UserLibraryModal({
|
||||
</Disabled>
|
||||
</Section>
|
||||
|
||||
{/* The exact cap is controlled by the backend env var
|
||||
MAX_EMBEDDED_IMAGES_PER_FILE (default 500). This copy is
|
||||
deliberately vague so it doesn't drift if the limit is
|
||||
tuned per-deployment; the precise number is surfaced in
|
||||
the rejection error the server returns. */}
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="end"
|
||||
padding={0.5}
|
||||
height="fit"
|
||||
>
|
||||
<Text secondaryBody text03>
|
||||
PDFs with many embedded images may be rejected.
|
||||
</Text>
|
||||
</Section>
|
||||
|
||||
{isLoading ? (
|
||||
<Section padding={2} height="fit">
|
||||
<Text secondaryBody text03>
|
||||
|
||||
@@ -4,6 +4,7 @@ import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import {
|
||||
LLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderResponse,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
@@ -136,14 +137,12 @@ export function useAdminLLMProviders() {
|
||||
* Used inside individual provider modals to pre-populate model lists
|
||||
* before the user has entered credentials.
|
||||
*
|
||||
* @param providerEndpoint - The provider's API endpoint name (e.g. "openai", "anthropic").
|
||||
* @param providerName - The provider's API endpoint name (e.g. "openai", "anthropic").
|
||||
* Pass `null` to suppress the request.
|
||||
*/
|
||||
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
|
||||
export function useWellKnownLLMProvider(providerName: LLMProviderName) {
|
||||
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
|
||||
providerEndpoint
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null,
|
||||
providerName ? `/api/admin/llm/built-in/options/${providerName}` : null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
|
||||
@@ -14,6 +14,7 @@ export enum LLMProviderName {
|
||||
BEDROCK = "bedrock",
|
||||
LITELLM_PROXY = "litellm_proxy",
|
||||
BIFROST = "bifrost",
|
||||
OPENAI_COMPATIBLE = "openai_compatible",
|
||||
CUSTOM = "custom",
|
||||
}
|
||||
|
||||
@@ -123,14 +124,11 @@ export interface LLMProviderFormProps {
|
||||
shouldMarkAsDefault?: boolean;
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
|
||||
/** The current default model name for this provider (from the global default). */
|
||||
defaultModelName?: string;
|
||||
/** Called after successful provider creation/update. */
|
||||
onSuccess?: () => void | Promise<void>;
|
||||
|
||||
// Onboarding-specific (only when variant === "onboarding")
|
||||
onboardingState?: OnboardingState;
|
||||
onboardingActions?: OnboardingActions;
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
}
|
||||
|
||||
// Param types for model fetching functions - use snake_case to match API structure
|
||||
@@ -181,6 +179,21 @@ export interface BifrostModelResponse {
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface OpenAICompatibleFetchParams {
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OpenAICompatibleModelResponse {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface VertexAIFetchParams {
|
||||
model_configurations?: ModelConfiguration[];
|
||||
}
|
||||
@@ -199,5 +212,6 @@ export type FetchModelsParams =
|
||||
| OpenRouterFetchParams
|
||||
| LiteLLMProxyFetchParams
|
||||
| BifrostFetchParams
|
||||
| OpenAICompatibleFetchParams
|
||||
| VertexAIFetchParams
|
||||
| LMStudioFetchParams;
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import type { RichStr } from "@opal/types";
|
||||
import type { RichStr, WithoutStyles } from "@opal/types";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
|
||||
import { useField, useFormikContext } from "formik";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
@@ -234,9 +235,27 @@ function ErrorTextLayout({ children, type = "error" }: ErrorTextLayoutProps) {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldSeparator - A horizontal rule with inline padding, used to visually separate field groups.
|
||||
*/
|
||||
function FieldSeparator() {
|
||||
return <Separator noPadding className="p-2" />;
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldPadder - Wraps a field in standard horizontal + vertical padding (`p-2 w-full`).
|
||||
*/
|
||||
type FieldPadderProps = WithoutStyles<React.HTMLAttributes<HTMLDivElement>>;
|
||||
function FieldPadder(props: FieldPadderProps) {
|
||||
return <div {...props} className="p-2 w-full" />;
|
||||
}
|
||||
|
||||
export {
|
||||
VerticalInputLayout as Vertical,
|
||||
HorizontalInputLayout as Horizontal,
|
||||
ErrorLayout as Error,
|
||||
ErrorTextLayout,
|
||||
FieldSeparator,
|
||||
FieldPadder,
|
||||
type FieldPadderProps,
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
SvgCloud,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgPlug,
|
||||
SvgServer,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
@@ -28,6 +29,7 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
@@ -45,6 +47,7 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
@@ -62,6 +65,7 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",
|
||||
|
||||
@@ -31,6 +31,7 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
@@ -43,9 +44,10 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
@@ -57,16 +59,18 @@ const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
// Client-side ordering for the "Add Provider" cards. The backend may return
|
||||
// wellKnownLLMProviders in an arbitrary order, so we sort explicitly here.
|
||||
const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
"openai",
|
||||
"anthropic",
|
||||
"vertex_ai",
|
||||
"bedrock",
|
||||
"azure",
|
||||
"litellm_proxy",
|
||||
"ollama_chat",
|
||||
"openrouter",
|
||||
"lm_studio",
|
||||
"bifrost",
|
||||
LLMProviderName.OPENAI,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
LLMProviderName.BEDROCK,
|
||||
LLMProviderName.AZURE,
|
||||
"litellm",
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
@@ -127,7 +131,7 @@ const PROVIDER_MODAL_MAP: Record<
|
||||
/>
|
||||
),
|
||||
lm_studio: (d, open, onOpenChange) => (
|
||||
<LMStudioForm
|
||||
<LMStudioModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
@@ -147,6 +151,13 @@ const PROVIDER_MODAL_MAP: Record<
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
openai_compatible: (d, open, onOpenChange) => (
|
||||
<OpenAICompatibleModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
@@ -341,7 +352,7 @@ function NewCustomProviderCard({
|
||||
// LLMConfigurationPage — main page component
|
||||
// ============================================================================
|
||||
|
||||
export default function LLMConfigurationPage() {
|
||||
export default function LLMProviderConfigurationPage() {
|
||||
const { mutate } = useSWRConfig();
|
||||
const { llmProviders: existingLlmProviders, defaultText } =
|
||||
useAdminLLMProviders();
|
||||
@@ -1,33 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
|
||||
const ANTHROPIC_PROVIDER_NAME = "anthropic";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export default function AnthropicModal({
|
||||
variant = "llm-configuration",
|
||||
@@ -35,143 +25,78 @@ export default function AnthropicModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
existingLlmProvider
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues = isOnboarding
|
||||
? {
|
||||
...buildOnboardingInitialValues(),
|
||||
name: ANTHROPIC_PROVIDER_NAME,
|
||||
provider: ANTHROPIC_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.ANTHROPIC}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: ANTHROPIC_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: ANTHROPIC_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<APIKeyField providerName="Anthropic" />
|
||||
<APIKeyField providerName="Anthropic" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. claude-sonnet-4-5" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,41 +1,35 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { useFormikContext } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
ModelsField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModelSelectionField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import {
|
||||
isValidAzureTargetUri,
|
||||
parseAzureTargetUri,
|
||||
} from "@/lib/azureTargetUri";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const AZURE_PROVIDER_NAME = "azure";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface AzureModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
@@ -45,6 +39,33 @@ interface AzureModalValues extends BaseLLMFormValues {
|
||||
deployment_name?: string;
|
||||
}
|
||||
|
||||
function AzureModelSelection() {
|
||||
const formikProps = useFormikContext<AzureModalValues>();
|
||||
return (
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onAddModel={(modelName) => {
|
||||
const current = formikProps.values.model_configurations;
|
||||
if (current.some((m) => m.name === modelName)) return;
|
||||
const updated = [
|
||||
...current,
|
||||
{
|
||||
name: modelName,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
];
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
if (!formikProps.values.test_model_name) {
|
||||
formikProps.setFieldValue("test_model_name", modelName);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function buildTargetUri(existingLlmProvider?: LLMProviderView): string {
|
||||
if (!existingLlmProvider?.api_base || !existingLlmProvider?.api_version) {
|
||||
return "";
|
||||
@@ -81,160 +102,105 @@ export default function AzureModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(AZURE_PROVIDER_NAME);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: AzureModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.AZURE,
|
||||
existingLlmProvider
|
||||
),
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
} as AzureModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
extra: {
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
},
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: AzureModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: AZURE_PROVIDER_NAME,
|
||||
provider: AZURE_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
target_uri: "",
|
||||
default_model_name: "",
|
||||
} as AzureModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.AZURE}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const processedValues = processValues(values);
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: AZURE_PROVIDER_NAME,
|
||||
payload: {
|
||||
...processedValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: AZURE_PROVIDER_NAME,
|
||||
values: processedValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.AZURE,
|
||||
values: processedValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={AZURE_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="target_uri"
|
||||
title="Target URI"
|
||||
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="target_uri"
|
||||
title="Target URI"
|
||||
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="target_uri"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<InputTypeInField
|
||||
name="target_uri"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<APIKeyField providerName="Azure" />
|
||||
<APIKeyField providerName="Azure" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<AzureModelSelection />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import { useFormikContext } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputSelectField from "@/refresh-components/form/InputSelectField";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
@@ -10,30 +10,22 @@ import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
|
||||
import { Card } from "@opal/components";
|
||||
@@ -41,9 +33,9 @@ import { Section } from "@/layouts/general-layouts";
|
||||
import { SvgAlertCircle } from "@opal/icons";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
|
||||
const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const AWS_REGION_OPTIONS = [
|
||||
{ name: "us-east-1", value: "us-east-1" },
|
||||
{ name: "us-east-2", value: "us-east-2" },
|
||||
@@ -79,26 +71,15 @@ interface BedrockModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface BedrockModalInternalsProps {
|
||||
formikProps: FormikProps<BedrockModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BedrockModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BedrockModalInternalsProps) {
|
||||
const formikProps = useFormikContext<BedrockModalValues>();
|
||||
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
|
||||
|
||||
useEffect(() => {
|
||||
@@ -115,11 +96,6 @@ function BedrockModalInternals({
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [authMethod]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isAuthComplete =
|
||||
authMethod === AUTH_METHOD_IAM ||
|
||||
(authMethod === AUTH_METHOD_ACCESS_KEY &&
|
||||
@@ -139,12 +115,12 @@ function BedrockModalInternals({
|
||||
formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
|
||||
aws_bearer_token_bedrock:
|
||||
formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.BEDROCK,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -159,16 +135,8 @@ function BedrockModalInternals({
|
||||
});
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={BEDROCK_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Section gap={1}>
|
||||
<InputLayouts.Vertical
|
||||
name={FIELD_AWS_REGION_NAME}
|
||||
@@ -222,7 +190,7 @@ function BedrockModalInternals({
|
||||
</InputSelect>
|
||||
</InputLayouts.Vertical>
|
||||
</Section>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
{authMethod === AUTH_METHOD_ACCESS_KEY && (
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
@@ -250,7 +218,7 @@ function BedrockModalInternals({
|
||||
)}
|
||||
|
||||
{authMethod === AUTH_METHOD_IAM && (
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Card backgroundVariant="none" borderVariant="solid">
|
||||
<Content
|
||||
icon={SvgAlertCircle}
|
||||
@@ -259,7 +227,7 @@ function BedrockModalInternals({
|
||||
sizePreset="main-ui"
|
||||
/>
|
||||
</Card>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
)}
|
||||
|
||||
{authMethod === AUTH_METHOD_LONG_TERM_API_KEY && (
|
||||
@@ -280,32 +248,24 @@ function BedrockModalInternals({
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. us.anthropic.claude-sonnet-4-5-v1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -315,86 +275,54 @@ export default function BedrockModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: BedrockModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.BEDROCK,
|
||||
existingLlmProvider
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ?? "",
|
||||
BEDROCK_AUTH_METHOD:
|
||||
(existingLlmProvider?.custom_config?.BEDROCK_AUTH_METHOD as string) ??
|
||||
"access_key",
|
||||
AWS_ACCESS_KEY_ID:
|
||||
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ?? "",
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
(existingLlmProvider?.custom_config?.AWS_SECRET_ACCESS_KEY as string) ??
|
||||
"",
|
||||
AWS_BEARER_TOKEN_BEDROCK:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
|
||||
},
|
||||
} as BedrockModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
extra: {
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
},
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: BedrockModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BEDROCK_PROVIDER_NAME,
|
||||
provider: BEDROCK_PROVIDER_NAME,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
AWS_REGION_NAME: "",
|
||||
BEDROCK_AUTH_METHOD: "access_key",
|
||||
AWS_ACCESS_KEY_ID: "",
|
||||
AWS_SECRET_ACCESS_KEY: "",
|
||||
AWS_BEARER_TOKEN_BEDROCK: "",
|
||||
},
|
||||
} as BedrockModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??
|
||||
"",
|
||||
BEDROCK_AUTH_METHOD:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.BEDROCK_AUTH_METHOD as string) ?? "access_key",
|
||||
AWS_ACCESS_KEY_ID:
|
||||
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ??
|
||||
"",
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_SECRET_ACCESS_KEY as string) ?? "",
|
||||
AWS_BEARER_TOKEN_BEDROCK:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.BEDROCK}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
@@ -407,51 +335,37 @@ export default function BedrockModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BEDROCK_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BEDROCK_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.BEDROCK,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BedrockModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<BedrockModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,45 +1,33 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
APIBaseField,
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
|
||||
const DEFAULT_API_BASE = "";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface BifrostModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
@@ -47,30 +35,15 @@ interface BifrostModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface BifrostModalInternalsProps {
|
||||
formikProps: FormikProps<BifrostModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BifrostModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BifrostModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
const formikProps = useFormikContext<BifrostModalValues>();
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
@@ -78,12 +51,12 @@ function BifrostModalInternals({
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.BIFROST,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -100,69 +73,39 @@ function BifrostModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.BIFROST}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
title="API Key"
|
||||
optional={true}
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -172,107 +115,64 @@ export default function BifrostModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BIFROST_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: BifrostModalValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.BIFROST,
|
||||
existingLlmProvider
|
||||
) as BifrostModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: BifrostModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BIFROST_PROVIDER_NAME,
|
||||
provider: BIFROST_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as BifrostModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.BIFROST}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.BIFROST,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BifrostModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<BifrostModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -70,7 +70,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
}
|
||||
) {
|
||||
const nameInput = screen.getByPlaceholderText("Display Name");
|
||||
const providerInput = screen.getByPlaceholderText("Provider Name");
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
"Provider Name as shown on LiteLLM"
|
||||
);
|
||||
|
||||
await user.type(nameInput, options.name);
|
||||
await user.type(providerInput, options.provider);
|
||||
@@ -498,7 +500,9 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
const nameInput = screen.getByPlaceholderText("Display Name");
|
||||
await user.type(nameInput, "Cloudflare Provider");
|
||||
|
||||
const providerInput = screen.getByPlaceholderText("Provider Name");
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
"Provider Name as shown on LiteLLM"
|
||||
);
|
||||
await user.type(providerInput, "cloudflare");
|
||||
|
||||
// Click "Add Line" button for custom config (aria-label from KeyValueInput)
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import { LLMProviderFormProps, ModelConfiguration } from "@/interfaces/llm";
|
||||
import { useFormikContext } from "formik";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useInitialValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildOnboardingInitialValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
ModelsAccessField,
|
||||
LLMConfigurationModalWrapper,
|
||||
FieldWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
@@ -32,6 +31,7 @@ import { Button, Card, EmptyMessageCard } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
@@ -109,13 +109,10 @@ function ModelConfigurationItem({
|
||||
);
|
||||
}
|
||||
|
||||
interface ModelConfigurationListProps {
|
||||
formikProps: FormikProps<{
|
||||
function ModelConfigurationList() {
|
||||
const formikProps = useFormikContext<{
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
}>;
|
||||
}
|
||||
|
||||
function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
|
||||
}>();
|
||||
const models = formikProps.values.model_configurations;
|
||||
|
||||
function handleChange(index: number, next: CustomModelConfiguration) {
|
||||
@@ -181,6 +178,19 @@ function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
|
||||
);
|
||||
}
|
||||
|
||||
function CustomConfigKeyValue() {
|
||||
const formikProps = useFormikContext<{ custom_config_list: KeyValue[] }>();
|
||||
return (
|
||||
<KeyValueInput
|
||||
items={formikProps.values.custom_config_list}
|
||||
onChange={(items) =>
|
||||
formikProps.setFieldValue("custom_config_list", items)
|
||||
}
|
||||
addButtonLabel="Add Line"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Custom Config Processing ─────────────────────────────────────────────────
|
||||
|
||||
function customConfigProcessing(items: KeyValue[]) {
|
||||
@@ -197,39 +207,36 @@ export default function CustomModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
undefined,
|
||||
defaultModelName
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.CUSTOM,
|
||||
existingLlmProvider
|
||||
),
|
||||
...(isOnboarding ? buildOnboardingInitialValues() : {}),
|
||||
provider: existingLlmProvider?.provider ?? "",
|
||||
api_version: existingLlmProvider?.api_version ?? "",
|
||||
model_configurations: existingLlmProvider?.model_configurations.map(
|
||||
(mc) => ({
|
||||
name: mc.name,
|
||||
display_name: mc.display_name ?? "",
|
||||
is_visible: mc.is_visible,
|
||||
max_input_tokens: mc.max_input_tokens ?? null,
|
||||
supports_image_input: mc.supports_image_input,
|
||||
supports_reasoning: mc.supports_reasoning,
|
||||
})
|
||||
) ?? [
|
||||
{
|
||||
name: "",
|
||||
display_name: "",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
@@ -260,12 +267,18 @@ export default function CustomModal({
|
||||
model_configurations: Yup.array(modelConfigurationSchema),
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.CUSTOM}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
setSubmitting(true);
|
||||
|
||||
const modelConfigurations = values.model_configurations
|
||||
@@ -285,127 +298,123 @@ export default function CustomModal({
|
||||
return;
|
||||
}
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
await submitOnboardingProvider({
|
||||
providerName: values.provider,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigurations,
|
||||
custom_config: customConfigProcessing(values.custom_config_list),
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: true,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
// Always send custom_config as a dict (even empty) so the backend
|
||||
// preserves it as non-null — this is the signal that the provider was
|
||||
// created via CustomModal.
|
||||
const customConfig = customConfigProcessing(values.custom_config_list);
|
||||
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(values.custom_config_list),
|
||||
},
|
||||
initialValues: {
|
||||
...initialValues,
|
||||
custom_config: customConfigProcessing(
|
||||
initialValues.custom_config_list
|
||||
),
|
||||
},
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: (values as Record<string, unknown>).provider as string,
|
||||
values: {
|
||||
...values,
|
||||
model_configurations: modelConfigurations,
|
||||
custom_config: customConfig,
|
||||
},
|
||||
initialValues: {
|
||||
...initialValues,
|
||||
custom_config: customConfigProcessing(
|
||||
initialValues.custom_config_list
|
||||
),
|
||||
},
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
isCustomProvider: true,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint="custom"
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="provider"
|
||||
title="Provider Name"
|
||||
subDescription={markdown(
|
||||
"Should be one of the providers listed at [LiteLLM](https://docs.litellm.ai/docs/providers)."
|
||||
)}
|
||||
>
|
||||
{!isOnboarding && (
|
||||
<Section gap={0}>
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
<InputTypeInField
|
||||
name="provider"
|
||||
placeholder="Provider Name as shown on LiteLLM"
|
||||
variant={existingLlmProvider ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="provider"
|
||||
title="Provider Name"
|
||||
subDescription="Should be one of the providers listed at https://docs.litellm.ai/docs/providers."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="provider"
|
||||
placeholder="Provider Name"
|
||||
variant={existingLlmProvider ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</Section>
|
||||
)}
|
||||
<APIBaseField optional />
|
||||
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical name="api_version" title="API Version" optional>
|
||||
<InputTypeInField name="api_version" />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldWrapper>
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Provider Configs"
|
||||
description="Add properties as needed by the model provider. This is passed to LiteLLM completion() call as arguments in the environment variable. See LiteLLM documentation for more instructions."
|
||||
widthVariant="full"
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
/>
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription="Paste your API key if your model provider requires authentication."
|
||||
/>
|
||||
|
||||
<KeyValueInput
|
||||
items={formikProps.values.custom_config_list}
|
||||
onChange={(items) =>
|
||||
formikProps.setFieldValue("custom_config_list", items)
|
||||
}
|
||||
addButtonLabel="Add Line"
|
||||
/>
|
||||
</Section>
|
||||
</FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Additional Configs"
|
||||
description={markdown(
|
||||
"Add extra properties as needed by the model provider. These are passed to LiteLLM's `completion()` call as [environment variables](https://docs.litellm.ai/docs/set_keys#environment-variables). See [documentation](https://docs.onyx.app/admins/ai_models/custom_inference_provider) for more instructions."
|
||||
)}
|
||||
widthVariant="full"
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
/>
|
||||
|
||||
<FieldSeparator />
|
||||
<CustomConfigKeyValue />
|
||||
</Section>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Section gap={0.5}>
|
||||
<FieldWrapper>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</FieldWrapper>
|
||||
|
||||
<Card>
|
||||
<ModelConfigurationList formikProps={formikProps as any} />
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<Section gap={0.5}>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Card sizeVariant="lg">
|
||||
<ModelConfigurationList />
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
interface LMStudioFormValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface LMStudioFormInternalsProps {
|
||||
formikProps: FormikProps<LMStudioFormValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LMStudioFormInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: LMStudioFormInternalsProps) {
|
||||
const initialApiKey =
|
||||
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ?? "";
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
|
||||
fetchModels(
|
||||
LLMProviderName.LM_STUDIO,
|
||||
{
|
||||
api_base: apiBase,
|
||||
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
|
||||
api_key_changed: apiKey !== initialApiKey,
|
||||
name: existingLlmProvider?.name,
|
||||
},
|
||||
signal
|
||||
).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setFetchedModels(data.models);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, initialApiKey, setFetchedModels]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
const apiBase = formikProps.values.api_base;
|
||||
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
useEffect(() => {
|
||||
if (apiBase) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(apiBase, apiKey, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [apiBase, apiKey, debouncedFetchModels, setFetchedModels]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.LM_STUDIO}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="The base URL for your LM Studio server."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="Your LM Studio API base URL"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
title="API Key"
|
||||
subDescription="Optional API key if your LM Studio server requires authentication."
|
||||
optional
|
||||
>
|
||||
<PasswordInputTypeInField
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
placeholder="API Key"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. llama3.1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LMStudioForm({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
LLMProviderName.LM_STUDIO
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: LMStudioFormValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: LLMProviderName.LM_STUDIO,
|
||||
provider: LLMProviderName.LM_STUDIO,
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY: "",
|
||||
},
|
||||
} as LMStudioFormValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY:
|
||||
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ??
|
||||
"",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LMStudioFormInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
217
web/src/sections/modals/llmConfig/LMStudioModal.tsx
Normal file
217
web/src/sections/modals/llmConfig/LMStudioModal.tsx
Normal file
@@ -0,0 +1,217 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues as BaseLLMModalValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
interface LMStudioModalValues extends BaseLLMModalValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface LMStudioModalInternalsProps {
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LMStudioModalInternals({
|
||||
existingLlmProvider,
|
||||
isOnboarding,
|
||||
}: LMStudioModalInternalsProps) {
|
||||
const formikProps = useFormikContext<LMStudioModalValues>();
|
||||
const initialApiKey = existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
|
||||
fetchModels(
|
||||
LLMProviderName.LM_STUDIO,
|
||||
{
|
||||
api_base: apiBase,
|
||||
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
|
||||
api_key_changed: apiKey !== initialApiKey,
|
||||
name: existingLlmProvider?.name,
|
||||
},
|
||||
signal
|
||||
).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
formikProps.setFieldValue("model_configurations", []);
|
||||
return;
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", data.models);
|
||||
});
|
||||
},
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[existingLlmProvider?.name, initialApiKey]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
const apiBase = formikProps.values.api_base;
|
||||
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
useEffect(() => {
|
||||
if (apiBase) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(apiBase, apiKey, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
formikProps.setFieldValue("model_configurations", []);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [apiBase, apiKey, debouncedFetchModels]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL for your LM Studio server."
|
||||
placeholder="Your LM Studio API base URL"
|
||||
/>
|
||||
|
||||
<APIKeyField
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
optional
|
||||
subDescription="Optional API key if your LM Studio server requires authentication."
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={false} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LMStudioModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const initialValues: LMStudioModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY: existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY,
|
||||
},
|
||||
} as LMStudioModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.LM_STUDIO}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<LMStudioModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -1,41 +1,32 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchLiteLLMProxyModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:4000";
|
||||
|
||||
@@ -45,30 +36,15 @@ interface LiteLLMProxyModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface LiteLLMProxyModalInternalsProps {
|
||||
formikProps: FormikProps<LiteLLMProxyModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LiteLLMProxyModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: LiteLLMProxyModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
const formikProps = useFormikContext<LiteLLMProxyModalValues>();
|
||||
|
||||
const isFetchDisabled =
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
@@ -77,12 +53,12 @@ function LiteLLMProxyModalInternals({
|
||||
const { models, error } = await fetchLiteLLMProxyModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.LITELLM_PROXY,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -98,58 +74,34 @@ function LiteLLMProxyModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.LITELLM_PROXY}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="The base URL for your LiteLLM Proxy server."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-litellm-proxy.com"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL for your LiteLLM Proxy server."
|
||||
placeholder="https://your-litellm-proxy.com"
|
||||
/>
|
||||
|
||||
<APIKeyField providerName="LiteLLM Proxy" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -159,109 +111,68 @@ export default function LiteLLMProxyModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
LLMProviderName.LITELLM_PROXY
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: LiteLLMProxyModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
} as LiteLLMProxyModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: LiteLLMProxyModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: LLMProviderName.LITELLM_PROXY,
|
||||
provider: LLMProviderName.LITELLM_PROXY,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as LiteLLMProxyModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.LITELLM_PROXY}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LiteLLMProxyModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<LiteLLMProxyModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,47 +1,44 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import * as Yup from "yup";
|
||||
import { Dispatch, SetStateAction, useMemo, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { Card } from "@opal/components";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
|
||||
const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
const TAB_SELF_HOSTED = "self-hosted";
|
||||
const TAB_CLOUD = "cloud";
|
||||
const CLOUD_API_BASE = "https://ollama.com";
|
||||
|
||||
enum Tab {
|
||||
TAB_SELF_HOSTED = "self-hosted",
|
||||
TAB_CLOUD = "cloud",
|
||||
}
|
||||
|
||||
interface OllamaModalValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
@@ -51,104 +48,65 @@ interface OllamaModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface OllamaModalInternalsProps {
|
||||
formikProps: FormikProps<OllamaModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
tab: Tab;
|
||||
setTab: Dispatch<SetStateAction<Tab>>;
|
||||
}
|
||||
|
||||
function OllamaModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
tab,
|
||||
setTab,
|
||||
}: OllamaModalInternalsProps) {
|
||||
const isInitialMount = useRef(true);
|
||||
const formikProps = useFormikContext<OllamaModalValues>();
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, signal: AbortSignal) => {
|
||||
fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
signal,
|
||||
}).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setFetchedModels(data.models);
|
||||
const isFetchDisabled = useMemo(
|
||||
() =>
|
||||
tab === Tab.TAB_SELF_HOSTED
|
||||
? !formikProps.values.api_base
|
||||
: !formikProps.values.custom_config.OLLAMA_API_KEY,
|
||||
[tab, formikProps]
|
||||
);
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
// Only Ollama cloud accepts API key
|
||||
const apiBase = formikProps.values.custom_config?.OLLAMA_API_KEY
|
||||
? CLOUD_API_BASE
|
||||
: formikProps.values.api_base;
|
||||
const { models, error } = await fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useOnMount(() => {
|
||||
if (existingLlmProvider) {
|
||||
handleFetchModels().catch((err) => {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, setFetchedModels]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
// Skip the initial fetch for new providers — api_base starts with a default
|
||||
// value, which would otherwise trigger a fetch before the user has done
|
||||
// anything. Existing providers should still auto-fetch on mount.
|
||||
useEffect(() => {
|
||||
if (isInitialMount.current) {
|
||||
isInitialMount.current = false;
|
||||
if (!existingLlmProvider) return;
|
||||
}
|
||||
|
||||
if (formikProps.values.api_base) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(formikProps.values.api_base, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
debouncedFetchModels,
|
||||
setFetchedModels,
|
||||
existingLlmProvider,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
const hasApiKey = !!formikProps.values.custom_config?.OLLAMA_API_KEY;
|
||||
const defaultTab =
|
||||
existingLlmProvider && hasApiKey ? TAB_CLOUD : TAB_SELF_HOSTED;
|
||||
});
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OLLAMA_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<>
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
<Tabs defaultValue={defaultTab}>
|
||||
<Tabs value={tab} onValueChange={(value) => setTab(value as Tab)}>
|
||||
<Tabs.List>
|
||||
<Tabs.Trigger value={TAB_SELF_HOSTED}>
|
||||
<Tabs.Trigger value={Tab.TAB_SELF_HOSTED}>
|
||||
Self-hosted Ollama
|
||||
</Tabs.Trigger>
|
||||
<Tabs.Trigger value={TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
|
||||
<Tabs.Trigger value={Tab.TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
|
||||
</Tabs.List>
|
||||
<Tabs.Content value={TAB_SELF_HOSTED}>
|
||||
<Tabs.Content value={Tab.TAB_SELF_HOSTED} padding={0}>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
@@ -161,7 +119,7 @@ function OllamaModalInternals({
|
||||
</InputLayouts.Vertical>
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={TAB_CLOUD}>
|
||||
<Tabs.Content value={Tab.TAB_CLOUD}>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.OLLAMA_API_KEY"
|
||||
title="API Key"
|
||||
@@ -178,31 +136,24 @@ function OllamaModalInternals({
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. llama3.1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -212,67 +163,55 @@ export default function OllamaModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } =
|
||||
useWellKnownLLMProvider(OLLAMA_PROVIDER_NAME);
|
||||
const apiKey = existingLlmProvider?.custom_config?.OLLAMA_API_KEY;
|
||||
const defaultTab =
|
||||
existingLlmProvider && !!apiKey ? Tab.TAB_CLOUD : Tab.TAB_SELF_HOSTED;
|
||||
const [tab, setTab] = useState<Tab>(defaultTab);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: OllamaModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY: apiKey,
|
||||
},
|
||||
} as OllamaModalValues;
|
||||
|
||||
const validationSchema = useMemo(
|
||||
() =>
|
||||
buildValidationSchema(isOnboarding, {
|
||||
apiBase: tab === Tab.TAB_SELF_HOSTED,
|
||||
extra:
|
||||
tab === Tab.TAB_CLOUD
|
||||
? {
|
||||
custom_config: Yup.object({
|
||||
OLLAMA_API_KEY: Yup.string().required("API Key is required"),
|
||||
}),
|
||||
}
|
||||
: undefined,
|
||||
}),
|
||||
[tab, isOnboarding]
|
||||
);
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: OllamaModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OLLAMA_PROVIDER_NAME,
|
||||
provider: OLLAMA_PROVIDER_NAME,
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY: "",
|
||||
},
|
||||
} as OllamaModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY:
|
||||
(existingLlmProvider?.custom_config?.OLLAMA_API_KEY as string) ??
|
||||
"",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OLLAMA_CHAT}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
@@ -285,50 +224,39 @@ export default function OllamaModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OLLAMA_CHAT,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<OllamaModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<OllamaModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
tab={tab}
|
||||
setTab={setTab}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
177
web/src/sections/modals/llmConfig/OpenAICompatibleModal.tsx
Normal file
177
web/src/sections/modals/llmConfig/OpenAICompatibleModal.tsx
Normal file
@@ -0,0 +1,177 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchOpenAICompatibleModels } from "@/app/admin/configuration/llm/utils";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIBaseField,
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface OpenAICompatibleModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface OpenAICompatibleModalInternalsProps {
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function OpenAICompatibleModalInternals({
|
||||
existingLlmProvider,
|
||||
isOnboarding,
|
||||
}: OpenAICompatibleModalInternalsProps) {
|
||||
const formikProps = useFormikContext<OpenAICompatibleModalValues>();
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
const { models, error } = await fetchOpenAICompatibleModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useEffect(() => {
|
||||
if (existingLlmProvider && !isFetchDisabled) {
|
||||
handleFetchModels().catch((err) => {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL of your OpenAI-compatible server."
|
||||
placeholder="http://localhost:8000/v1"
|
||||
/>
|
||||
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription={markdown(
|
||||
"Provide an API key if your server requires authentication."
|
||||
)}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function OpenAICompatibleModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
existingLlmProvider
|
||||
) as OpenAICompatibleModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENAI_COMPATIBLE}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENAI_COMPATIBLE,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<OpenAICompatibleModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -1,33 +1,23 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
|
||||
const OPENAI_PROVIDER_NAME = "openai";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export default function OpenAIModal({
|
||||
variant = "llm-configuration",
|
||||
@@ -35,141 +25,78 @@ export default function OpenAIModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } =
|
||||
useWellKnownLLMProvider(OPENAI_PROVIDER_NAME);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENAI,
|
||||
existingLlmProvider
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues = isOnboarding
|
||||
? {
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OPENAI_PROVIDER_NAME,
|
||||
provider: OPENAI_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENAI}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OPENAI_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OPENAI_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENAI,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OPENAI_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<APIKeyField providerName="OpenAI" />
|
||||
<APIKeyField providerName="OpenAI" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-5.2" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,73 +1,50 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchOpenRouterModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelsField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const OPENROUTER_PROVIDER_NAME = "openrouter";
|
||||
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
|
||||
|
||||
interface OpenRouterModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface OpenRouterModalInternalsProps {
|
||||
formikProps: FormikProps<OpenRouterModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function OpenRouterModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: OpenRouterModalInternalsProps) {
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
const formikProps = useFormikContext<OpenRouterModalValues>();
|
||||
|
||||
const isFetchDisabled =
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
@@ -76,12 +53,12 @@ function OpenRouterModalInternals({
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
provider_name: LLMProviderName.OPENROUTER,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
setFetchedModels(models);
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -97,58 +74,34 @@ function OpenRouterModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OPENROUTER_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="Your OpenRouter base URL"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
|
||||
placeholder="Your OpenRouter base URL"
|
||||
/>
|
||||
|
||||
<APIKeyField providerName="OpenRouter" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. openai/gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -158,109 +111,68 @@ export default function OpenRouterModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
OPENROUTER_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: OpenRouterModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENROUTER,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
} as OpenRouterModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: OpenRouterModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OPENROUTER_PROVIDER_NAME,
|
||||
provider: OPENROUTER_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as OpenRouterModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENROUTER}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OPENROUTER_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OPENROUTER_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENROUTER,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<OpenRouterModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
<OpenRouterModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,38 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { FileUploadFormField } from "@/components/Field";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
|
||||
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
|
||||
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
|
||||
const VERTEXAI_DEFAULT_LOCATION = "global";
|
||||
|
||||
interface VertexAIModalValues extends BaseLLMFormValues {
|
||||
@@ -48,87 +37,50 @@ export default function VertexAIModal({
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
VERTEXAI_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
const initialValues: VertexAIModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
existingLlmProvider
|
||||
),
|
||||
custom_config: {
|
||||
vertex_credentials:
|
||||
(existingLlmProvider?.custom_config?.vertex_credentials as string) ??
|
||||
"",
|
||||
vertex_location:
|
||||
(existingLlmProvider?.custom_config?.vertex_location as string) ??
|
||||
VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
} as VertexAIModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
extra: {
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
},
|
||||
});
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: VertexAIModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: VERTEXAI_PROVIDER_NAME,
|
||||
provider: VERTEXAI_PROVIDER_NAME,
|
||||
default_model_name: VERTEXAI_DEFAULT_MODEL,
|
||||
custom_config: {
|
||||
vertex_credentials: "",
|
||||
vertex_location: VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
} as VertexAIModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
custom_config: {
|
||||
vertex_credentials:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.vertex_credentials as string) ?? "",
|
||||
vertex_location:
|
||||
(existingLlmProvider?.custom_config?.vertex_location as string) ??
|
||||
VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
});
|
||||
if (open === false) return null;
|
||||
|
||||
return (
|
||||
<Formik
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.VERTEX_AI}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(
|
||||
([key, v]) => key === "vertex_credentials" || v !== ""
|
||||
@@ -143,101 +95,75 @@ export default function VertexAIModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: VERTEXAI_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === VERTEXAI_DEFAULT_MODEL,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: VERTEXAI_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.VERTEX_AI,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={VERTEXAI_PROVIDER_NAME}
|
||||
providerName={VERTEXAI_DISPLAY_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_location"
|
||||
title="Google Cloud Region Name"
|
||||
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_location"
|
||||
title="Google Cloud Region Name"
|
||||
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="custom_config.vertex_location"
|
||||
placeholder={VERTEXAI_DEFAULT_LOCATION}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<InputTypeInField
|
||||
name="custom_config.vertex_location"
|
||||
placeholder={VERTEXAI_DEFAULT_LOCATION}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_credentials"
|
||||
title="API Key"
|
||||
subDescription="Attach your API key JSON from Google Cloud to access your models."
|
||||
>
|
||||
<FileUploadFormField
|
||||
name="custom_config.vertex_credentials"
|
||||
label=""
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_credentials"
|
||||
title="API Key"
|
||||
subDescription="Attach your API key JSON from Google Cloud to access your models."
|
||||
>
|
||||
<FileUploadFormField
|
||||
name="custom_config.vertex_credentials"
|
||||
label=""
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{!isOnboarding && (
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gemini-2.5-pro" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && <ModelsAccessField formikProps={formikProps} />}
|
||||
</LLMConfigurationModalWrapper>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,9 +7,10 @@ import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
@@ -54,11 +55,13 @@ export function getModalForExistingProvider(
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...props} />;
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioForm {...props} />;
|
||||
return <LMStudioModal {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
|
||||
@@ -1,30 +1,31 @@
|
||||
"use client";
|
||||
|
||||
import { ReactNode } from "react";
|
||||
import { Form, FormikProps } from "formik";
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { Formik, Form, useFormikContext } from "formik";
|
||||
import type { FormikConfig } from "formik";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
|
||||
import { LLMProviderView, ModelConfiguration } from "@/interfaces/llm";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import Switch from "@/refresh-components/inputs/Switch";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button, LineItemButton, Tag } from "@opal/components";
|
||||
import { Button, LineItemButton } from "@opal/components";
|
||||
import { BaseLLMFormValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { WithoutStyles } from "@opal/types";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Disabled, Hoverable } from "@opal/core";
|
||||
import { Content } from "@opal/layouts";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgOnyxOctagon,
|
||||
SvgOrganization,
|
||||
SvgPlusCircle,
|
||||
SvgRefreshCw,
|
||||
SvgSparkle,
|
||||
SvgUserManage,
|
||||
@@ -46,27 +47,14 @@ import {
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
|
||||
export function FieldSeparator() {
|
||||
return <Separator noPadding className="px-2" />;
|
||||
}
|
||||
|
||||
export type FieldWrapperProps = WithoutStyles<
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>;
|
||||
|
||||
export function FieldWrapper(props: FieldWrapperProps) {
|
||||
return <div {...props} className="p-2 w-full" />;
|
||||
}
|
||||
|
||||
// ─── DisplayNameField ────────────────────────────────────────────────────────
|
||||
|
||||
export interface DisplayNameFieldProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
return (
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="name"
|
||||
title="Display Name"
|
||||
@@ -78,56 +66,68 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
variant={disabled ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── APIKeyField ─────────────────────────────────────────────────────────────
|
||||
|
||||
export interface APIKeyFieldProps {
|
||||
/** Formik field name. @default "api_key" */
|
||||
name?: string;
|
||||
optional?: boolean;
|
||||
providerName?: string;
|
||||
subDescription?: string | RichStr;
|
||||
}
|
||||
|
||||
export function APIKeyField({
|
||||
name = "api_key",
|
||||
optional = false,
|
||||
providerName,
|
||||
subDescription,
|
||||
}: APIKeyFieldProps) {
|
||||
return (
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
name={name}
|
||||
title="API Key"
|
||||
subDescription={
|
||||
providerName
|
||||
? `Paste your API key from ${providerName} to access your models.`
|
||||
: "Paste your API key to access your models."
|
||||
subDescription
|
||||
? subDescription
|
||||
: providerName
|
||||
? `Paste your API key from ${providerName} to access your models.`
|
||||
: "Paste your API key to access your models."
|
||||
}
|
||||
optional={optional}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
<PasswordInputTypeInField name={name} />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── SingleDefaultModelField ─────────────────────────────────────────────────
|
||||
// ─── APIBaseField ───────────────────────────────────────────────────────────
|
||||
|
||||
export interface SingleDefaultModelFieldProps {
|
||||
export interface APIBaseFieldProps {
|
||||
optional?: boolean;
|
||||
subDescription?: string | RichStr;
|
||||
placeholder?: string;
|
||||
}
|
||||
|
||||
export function SingleDefaultModelField({
|
||||
placeholder = "E.g. gpt-4o",
|
||||
}: SingleDefaultModelFieldProps) {
|
||||
export function APIBaseField({
|
||||
optional = false,
|
||||
subDescription,
|
||||
placeholder = "https://",
|
||||
}: APIBaseFieldProps) {
|
||||
return (
|
||||
<InputLayouts.Vertical
|
||||
name="default_model_name"
|
||||
title="Default Model"
|
||||
description="The model to use by default for this provider unless otherwise specified."
|
||||
>
|
||||
<InputTypeInField name="default_model_name" placeholder={placeholder} />
|
||||
</InputLayouts.Vertical>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription={subDescription}
|
||||
optional={optional}
|
||||
>
|
||||
<InputTypeInField name="api_base" placeholder={placeholder} />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -137,13 +137,8 @@ export function SingleDefaultModelField({
|
||||
const GROUP_PREFIX = "group:";
|
||||
const AGENT_PREFIX = "agent:";
|
||||
|
||||
interface ModelsAccessFieldProps<T> {
|
||||
formikProps: FormikProps<T>;
|
||||
}
|
||||
|
||||
export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
formikProps,
|
||||
}: ModelsAccessFieldProps<T>) {
|
||||
export function ModelAccessField() {
|
||||
const formikProps = useFormikContext<BaseLLMFormValues>();
|
||||
const { agents } = useAgents();
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
const { data: usersData } = useUsers({ includeApiKeys: false });
|
||||
@@ -229,7 +224,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full">
|
||||
<FieldWrapper>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Horizontal
|
||||
name="is_public"
|
||||
title="Models Access"
|
||||
@@ -250,7 +245,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</InputLayouts.Horizontal>
|
||||
</FieldWrapper>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
{!isPublic && (
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
@@ -316,7 +311,7 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
<InputLayouts.FieldSeparator />
|
||||
|
||||
{selectedAgentIds.length > 0 ? (
|
||||
<div className="grid grid-cols-2 gap-1 w-full">
|
||||
@@ -371,84 +366,73 @@ export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
|
||||
// ─── ModelsField ─────────────────────────────────────────────────────
|
||||
|
||||
export interface ModelsFieldProps<T> {
|
||||
formikProps: FormikProps<T>;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
recommendedDefaultModel: SimpleKnownModel | null;
|
||||
export interface ModelSelectionFieldProps {
|
||||
shouldShowAutoUpdateToggle: boolean;
|
||||
/** Called when the user clicks the refresh button to re-fetch models. */
|
||||
onRefetch?: () => Promise<void> | void;
|
||||
/** Called when the user adds a custom model name (e.g. for Azure). */
|
||||
onAddModel?: (modelName: string) => void;
|
||||
}
|
||||
|
||||
export function ModelsField<T extends BaseLLMFormValues>({
|
||||
formikProps,
|
||||
modelConfigurations,
|
||||
recommendedDefaultModel,
|
||||
export function ModelSelectionField({
|
||||
shouldShowAutoUpdateToggle,
|
||||
onRefetch,
|
||||
}: ModelsFieldProps<T>) {
|
||||
const isAutoMode = formikProps.values.is_auto_mode;
|
||||
const selectedModels = formikProps.values.selected_model_names ?? [];
|
||||
const defaultModel = formikProps.values.default_model_name;
|
||||
onAddModel,
|
||||
}: ModelSelectionFieldProps) {
|
||||
const formikProps = useFormikContext<BaseLLMFormValues>();
|
||||
const [newModelName, setNewModelName] = useState("");
|
||||
// When the auto-update toggle is hidden, auto mode should have no effect —
|
||||
// otherwise models can't be deselected and "Select All" stays disabled.
|
||||
const isAutoMode =
|
||||
shouldShowAutoUpdateToggle && formikProps.values.is_auto_mode;
|
||||
const models = formikProps.values.model_configurations;
|
||||
|
||||
function handleCheckboxChange(modelName: string, checked: boolean) {
|
||||
// Read current values inside the handler to avoid stale closure issues
|
||||
const currentSelected = formikProps.values.selected_model_names ?? [];
|
||||
const currentDefault = formikProps.values.default_model_name;
|
||||
|
||||
if (checked) {
|
||||
const newSelected = [...currentSelected, modelName];
|
||||
formikProps.setFieldValue("selected_model_names", newSelected);
|
||||
// If this is the first model, set it as default
|
||||
if (currentSelected.length === 0) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
}
|
||||
} else {
|
||||
const newSelected = currentSelected.filter((name) => name !== modelName);
|
||||
formikProps.setFieldValue("selected_model_names", newSelected);
|
||||
// If removing the default, set the first remaining model as default
|
||||
if (currentDefault === modelName && newSelected.length > 0) {
|
||||
formikProps.setFieldValue("default_model_name", newSelected[0]);
|
||||
} else if (newSelected.length === 0) {
|
||||
formikProps.setFieldValue("default_model_name", undefined);
|
||||
}
|
||||
// Snapshot the original model visibility so we can restore it when
|
||||
// toggling auto mode back on.
|
||||
const originalModelsRef = useRef(models);
|
||||
useEffect(() => {
|
||||
if (originalModelsRef.current.length === 0 && models.length > 0) {
|
||||
originalModelsRef.current = models;
|
||||
}
|
||||
}
|
||||
}, [models]);
|
||||
|
||||
function handleSetDefault(modelName: string) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
// Automatically derive test_model_name from model_configurations.
|
||||
// Any change to visibility or the model list syncs this automatically.
|
||||
useEffect(() => {
|
||||
const firstVisible = models.find((m) => m.is_visible)?.name;
|
||||
if (firstVisible !== formikProps.values.test_model_name) {
|
||||
formikProps.setFieldValue("test_model_name", firstVisible);
|
||||
}
|
||||
}, [models]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
function setVisibility(modelName: string, visible: boolean) {
|
||||
const updated = models.map((m) =>
|
||||
m.name === modelName ? { ...m, is_visible: visible } : m
|
||||
);
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
}
|
||||
|
||||
function handleToggleAutoMode(nextIsAutoMode: boolean) {
|
||||
formikProps.setFieldValue("is_auto_mode", nextIsAutoMode);
|
||||
formikProps.setFieldValue(
|
||||
"selected_model_names",
|
||||
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
|
||||
);
|
||||
formikProps.setFieldValue(
|
||||
"default_model_name",
|
||||
recommendedDefaultModel?.name ?? undefined
|
||||
);
|
||||
}
|
||||
|
||||
const allSelected =
|
||||
modelConfigurations.length > 0 &&
|
||||
modelConfigurations.every((m) => selectedModels.includes(m.name));
|
||||
|
||||
function handleToggleSelectAll() {
|
||||
if (allSelected) {
|
||||
formikProps.setFieldValue("selected_model_names", []);
|
||||
formikProps.setFieldValue("default_model_name", undefined);
|
||||
} else {
|
||||
const allNames = modelConfigurations.map((m) => m.name);
|
||||
formikProps.setFieldValue("selected_model_names", allNames);
|
||||
if (!formikProps.values.default_model_name && allNames.length > 0) {
|
||||
formikProps.setFieldValue("default_model_name", allNames[0]);
|
||||
}
|
||||
if (nextIsAutoMode) {
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
originalModelsRef.current
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
|
||||
const allSelected = models.length > 0 && models.every((m) => m.is_visible);
|
||||
|
||||
function handleToggleSelectAll() {
|
||||
const nextVisible = !allSelected;
|
||||
const updated = models.map((m) => ({
|
||||
...m,
|
||||
is_visible: nextVisible,
|
||||
}));
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
}
|
||||
|
||||
const visibleModels = models.filter((m) => m.is_visible);
|
||||
|
||||
return (
|
||||
<Card backgroundVariant="light" borderVariant="none" sizeVariant="lg">
|
||||
@@ -460,15 +444,14 @@ export function ModelsField<T extends BaseLLMFormValues>({
|
||||
center
|
||||
>
|
||||
<Section flexDirection="row" gap={0}>
|
||||
<Disabled disabled={isAutoMode || modelConfigurations.length === 0}>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
onClick={handleToggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Unselect All" : "Select All"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isAutoMode || models.length === 0}
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
onClick={handleToggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Unselect All" : "Select All"}
|
||||
</Button>
|
||||
{onRefetch && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
@@ -489,91 +472,75 @@ export function ModelsField<T extends BaseLLMFormValues>({
|
||||
</Section>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
{modelConfigurations.length === 0 ? (
|
||||
{models.length === 0 ? (
|
||||
<EmptyMessageCard title="No models available." />
|
||||
) : (
|
||||
<Section gap={0.25}>
|
||||
{isAutoMode
|
||||
? // Auto mode: read-only display
|
||||
visibleModels.map((model) => (
|
||||
<Hoverable.Root
|
||||
? visibleModels.map((model) => (
|
||||
<LineItemButton
|
||||
key={model.name}
|
||||
group="LLMConfigurationButton"
|
||||
widthVariant="full"
|
||||
>
|
||||
<LineItemButton
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state="selected"
|
||||
icon={() => <Checkbox checked />}
|
||||
title={model.display_name || model.name}
|
||||
rightChildren={
|
||||
model.name === defaultModel ? (
|
||||
<Section>
|
||||
<Tag title="Default Model" color="blue" />
|
||||
</Section>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</Hoverable.Root>
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state="selected"
|
||||
icon={() => <Checkbox checked />}
|
||||
title={model.display_name || model.name}
|
||||
/>
|
||||
))
|
||||
: // Manual mode: checkbox selection
|
||||
modelConfigurations.map((modelConfiguration) => {
|
||||
const isSelected = selectedModels.includes(
|
||||
modelConfiguration.name
|
||||
);
|
||||
const isDefault = defaultModel === modelConfiguration.name;
|
||||
: models.map((model) => (
|
||||
<LineItemButton
|
||||
key={model.name}
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state={model.is_visible ? "selected" : "empty"}
|
||||
icon={() => <Checkbox checked={model.is_visible} />}
|
||||
title={model.name}
|
||||
onClick={() => setVisibility(model.name, !model.is_visible)}
|
||||
/>
|
||||
))}
|
||||
</Section>
|
||||
)}
|
||||
|
||||
return (
|
||||
<Hoverable.Root
|
||||
key={modelConfiguration.name}
|
||||
group="LLMConfigurationButton"
|
||||
widthVariant="full"
|
||||
>
|
||||
<LineItemButton
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state={isSelected ? "selected" : "empty"}
|
||||
icon={() => <Checkbox checked={isSelected} />}
|
||||
title={modelConfiguration.name}
|
||||
onClick={() =>
|
||||
handleCheckboxChange(
|
||||
modelConfiguration.name,
|
||||
!isSelected
|
||||
)
|
||||
}
|
||||
rightChildren={
|
||||
isSelected ? (
|
||||
isDefault ? (
|
||||
<Section>
|
||||
<Tag color="blue" title="Default Model" />
|
||||
</Section>
|
||||
) : (
|
||||
<Hoverable.Item
|
||||
group="LLMConfigurationButton"
|
||||
variant="opacity-on-hover"
|
||||
>
|
||||
<Button
|
||||
size="sm"
|
||||
prominence="internal"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleSetDefault(modelConfiguration.name);
|
||||
}}
|
||||
type="button"
|
||||
>
|
||||
Set as default
|
||||
</Button>
|
||||
</Hoverable.Item>
|
||||
)
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</Hoverable.Root>
|
||||
);
|
||||
})}
|
||||
{onAddModel && !isAutoMode && (
|
||||
<Section flexDirection="row" gap={0.5}>
|
||||
<div className="flex-1">
|
||||
<InputTypeIn
|
||||
placeholder="Enter model name"
|
||||
value={newModelName}
|
||||
onChange={(e) => setNewModelName(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && newModelName.trim()) {
|
||||
e.preventDefault();
|
||||
const trimmed = newModelName.trim();
|
||||
if (!models.some((m) => m.name === trimmed)) {
|
||||
onAddModel(trimmed);
|
||||
setNewModelName("");
|
||||
}
|
||||
}
|
||||
}}
|
||||
showClearButton={false}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
icon={SvgPlusCircle}
|
||||
type="button"
|
||||
disabled={
|
||||
!newModelName.trim() ||
|
||||
models.some((m) => m.name === newModelName.trim())
|
||||
}
|
||||
onClick={() => {
|
||||
const trimmed = newModelName.trim();
|
||||
if (trimmed && !models.some((m) => m.name === trimmed)) {
|
||||
onAddModel(trimmed);
|
||||
setNewModelName("");
|
||||
}
|
||||
}}
|
||||
>
|
||||
Add Model
|
||||
</Button>
|
||||
</Section>
|
||||
)}
|
||||
|
||||
@@ -593,41 +560,96 @@ export function ModelsField<T extends BaseLLMFormValues>({
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLMConfigurationModalWrapper
|
||||
// ============================================================================
|
||||
// ─── ModalWrapper ─────────────────────────────────────────────────────
|
||||
|
||||
interface LLMConfigurationModalWrapperProps {
|
||||
providerEndpoint: string;
|
||||
providerName?: string;
|
||||
existingProviderName?: string;
|
||||
export interface ModalWrapperProps<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
llmProvider?: LLMProviderView;
|
||||
onClose: () => void;
|
||||
isFormValid: boolean;
|
||||
isDirty?: boolean;
|
||||
isTesting?: boolean;
|
||||
isSubmitting?: boolean;
|
||||
children: ReactNode;
|
||||
initialValues: T;
|
||||
validationSchema: FormikConfig<T>["validationSchema"];
|
||||
onSubmit: FormikConfig<T>["onSubmit"];
|
||||
children: React.ReactNode;
|
||||
}
|
||||
export function ModalWrapper<T extends BaseLLMFormValues = BaseLLMFormValues>({
|
||||
providerName,
|
||||
llmProvider,
|
||||
onClose,
|
||||
initialValues,
|
||||
validationSchema,
|
||||
onSubmit,
|
||||
children,
|
||||
}: ModalWrapperProps<T>) {
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount
|
||||
onSubmit={onSubmit}
|
||||
>
|
||||
{() => (
|
||||
<ModalWrapperInner
|
||||
providerName={providerName}
|
||||
llmProvider={llmProvider}
|
||||
onClose={onClose}
|
||||
modelConfigurations={initialValues.model_configurations}
|
||||
>
|
||||
{children}
|
||||
</ModalWrapperInner>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
export function LLMConfigurationModalWrapper({
|
||||
providerEndpoint,
|
||||
interface ModalWrapperInnerProps {
|
||||
providerName: string;
|
||||
llmProvider?: LLMProviderView;
|
||||
onClose: () => void;
|
||||
modelConfigurations?: ModelConfiguration[];
|
||||
children: React.ReactNode;
|
||||
}
|
||||
function ModalWrapperInner({
|
||||
providerName,
|
||||
existingProviderName,
|
||||
llmProvider,
|
||||
onClose,
|
||||
isFormValid,
|
||||
isDirty,
|
||||
isTesting,
|
||||
isSubmitting,
|
||||
modelConfigurations,
|
||||
children,
|
||||
}: LLMConfigurationModalWrapperProps) {
|
||||
const busy = isTesting || isSubmitting;
|
||||
const providerIcon = getProviderIcon(providerEndpoint);
|
||||
const providerDisplayName =
|
||||
providerName ?? getProviderDisplayName(providerEndpoint);
|
||||
const providerProductName = getProviderProductName(providerEndpoint);
|
||||
}: ModalWrapperInnerProps) {
|
||||
const { isValid, dirty, isSubmitting, status, setFieldValue, values } =
|
||||
useFormikContext<BaseLLMFormValues>();
|
||||
|
||||
const title = existingProviderName
|
||||
? `Configure "${existingProviderName}"`
|
||||
// When SWR resolves after mount, populate model_configurations if still
|
||||
// empty. test_model_name is then derived automatically by
|
||||
// ModelSelectionField's useEffect.
|
||||
useEffect(() => {
|
||||
if (
|
||||
modelConfigurations &&
|
||||
modelConfigurations.length > 0 &&
|
||||
values.model_configurations.length === 0
|
||||
) {
|
||||
setFieldValue("model_configurations", modelConfigurations);
|
||||
}
|
||||
}, [modelConfigurations]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const isTesting = status?.isTesting === true;
|
||||
const busy = isTesting || isSubmitting;
|
||||
|
||||
const disabledTooltip = busy
|
||||
? undefined
|
||||
: !isValid
|
||||
? "Please fill in all required fields."
|
||||
: !dirty
|
||||
? "No changes to save."
|
||||
: undefined;
|
||||
|
||||
const providerIcon = getProviderIcon(providerName);
|
||||
const providerDisplayName = getProviderDisplayName(providerName);
|
||||
const providerProductName = getProviderProductName(providerName);
|
||||
|
||||
const title = llmProvider
|
||||
? `Configure "${llmProvider.name}"`
|
||||
: `Set up ${providerProductName}`;
|
||||
const description = `Connect to ${providerDisplayName} and set up your ${providerProductName} models.`;
|
||||
|
||||
@@ -650,21 +672,20 @@ export function LLMConfigurationModalWrapper({
|
||||
<Button prominence="secondary" onClick={onClose} type="button">
|
||||
Cancel
|
||||
</Button>
|
||||
<Disabled
|
||||
disabled={
|
||||
!isFormValid || busy || (!!existingProviderName && !isDirty)
|
||||
}
|
||||
<Button
|
||||
disabled={!isValid || !dirty || busy}
|
||||
type="submit"
|
||||
icon={busy ? SimpleLoader : undefined}
|
||||
tooltip={disabledTooltip}
|
||||
>
|
||||
<Button type="submit" icon={busy ? SimpleLoader : undefined}>
|
||||
{existingProviderName
|
||||
? busy
|
||||
? "Updating"
|
||||
: "Update"
|
||||
: busy
|
||||
? "Connecting"
|
||||
: "Connect"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
{llmProvider?.name
|
||||
? busy
|
||||
? "Updating"
|
||||
: "Update"
|
||||
: busy
|
||||
? "Connecting"
|
||||
: "Connect"}
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Form>
|
||||
</Modal.Content>
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { parseAzureTargetUri } from "@/lib/azureTargetUri";
|
||||
@@ -18,13 +13,11 @@ import {
|
||||
} from "@/lib/analytics";
|
||||
import {
|
||||
BaseLLMFormValues,
|
||||
SubmitLLMProviderParams,
|
||||
SubmitOnboardingProviderParams,
|
||||
TestApiKeyResult,
|
||||
filterModelConfigurations,
|
||||
getAutoModeModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
|
||||
// ─── Test helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
const submitLlmTestRequest = async (
|
||||
payload: Record<string, unknown>,
|
||||
fallbackErrorMessage: string
|
||||
@@ -50,161 +43,6 @@ const submitLlmTestRequest = async (
|
||||
}
|
||||
};
|
||||
|
||||
export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
providerName,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
hideSuccess,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
}: SubmitLLMProviderParams<T>): Promise<void> => {
|
||||
setSubmitting(true);
|
||||
|
||||
const { selected_model_names: visibleModels, api_key, ...rest } = values;
|
||||
|
||||
// In auto mode, use recommended models from descriptor
|
||||
// In manual mode, use user's selection
|
||||
let filteredModelConfigurations: ModelConfiguration[];
|
||||
let finalDefaultModelName = rest.default_model_name;
|
||||
|
||||
if (values.is_auto_mode) {
|
||||
filteredModelConfigurations =
|
||||
getAutoModeModelConfigurations(modelConfigurations);
|
||||
|
||||
// In auto mode, use the first recommended model as default if current default isn't in the list
|
||||
const visibleModelNames = new Set(
|
||||
filteredModelConfigurations.map((m) => m.name)
|
||||
);
|
||||
if (
|
||||
finalDefaultModelName &&
|
||||
!visibleModelNames.has(finalDefaultModelName)
|
||||
) {
|
||||
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
|
||||
}
|
||||
} else {
|
||||
filteredModelConfigurations = filterModelConfigurations(
|
||||
modelConfigurations,
|
||||
visibleModels,
|
||||
rest.default_model_name as string | undefined
|
||||
);
|
||||
}
|
||||
|
||||
const customConfigChanged = !isEqual(
|
||||
values.custom_config,
|
||||
initialValues.custom_config
|
||||
);
|
||||
|
||||
const normalizedApiBase =
|
||||
typeof rest.api_base === "string" && rest.api_base.trim() === ""
|
||||
? undefined
|
||||
: rest.api_base;
|
||||
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: normalizedApiBase,
|
||||
default_model_name: finalDefaultModelName,
|
||||
api_key,
|
||||
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
|
||||
custom_config_changed: customConfigChanged,
|
||||
model_configurations: filteredModelConfigurations,
|
||||
};
|
||||
|
||||
// Test the configuration
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setIsTesting(true);
|
||||
|
||||
const response = await fetch("/api/admin/llm/test", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: finalDefaultModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
});
|
||||
setIsTesting(false);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(errorMsg);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
}`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
const fullErrorMsg = existingLlmProvider
|
||||
? `Failed to update provider: ${errorMsg}`
|
||||
: `Failed to enable provider: ${errorMsg}`;
|
||||
toast.error(fullErrorMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (shouldMarkAsDefault) {
|
||||
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: finalDefaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const errorMsg = (await setDefaultResponse.json()).detail;
|
||||
toast.error(`Failed to set provider as default: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
onClose();
|
||||
|
||||
if (!hideSuccess) {
|
||||
const successMsg = existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!";
|
||||
toast.success(successMsg);
|
||||
}
|
||||
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
});
|
||||
|
||||
setSubmitting(false);
|
||||
};
|
||||
|
||||
export const testApiKeyHelper = async (
|
||||
providerName: string,
|
||||
formValues: Record<string, unknown>,
|
||||
@@ -241,7 +79,7 @@ export const testApiKeyHelper = async (
|
||||
...((formValues?.custom_config as Record<string, unknown>) ?? {}),
|
||||
...(customConfigOverride ?? {}),
|
||||
},
|
||||
model: modelName ?? (formValues?.default_model_name as string) ?? "",
|
||||
model: modelName ?? (formValues?.test_model_name as string) ?? "",
|
||||
};
|
||||
|
||||
return await submitLlmTestRequest(
|
||||
@@ -259,96 +97,148 @@ export const testCustomProvider = async (
|
||||
);
|
||||
};
|
||||
|
||||
export const submitOnboardingProvider = async ({
|
||||
// ─── Submit provider ──────────────────────────────────────────────────────
|
||||
|
||||
export interface SubmitProviderParams<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
values: T;
|
||||
initialValues: T;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
isCustomProvider?: boolean;
|
||||
setStatus: (status: Record<string, unknown>) => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
onClose: () => void;
|
||||
/** Called after successful create/update + set-default. Use for cache refresh, state updates, toasts, etc. */
|
||||
onSuccess?: () => void | Promise<void>;
|
||||
/** Analytics source for tracking. @default LLMProviderConfiguredSource.ADMIN_PAGE */
|
||||
analyticsSource?: LLMProviderConfiguredSource;
|
||||
}
|
||||
|
||||
export async function submitProvider<T extends BaseLLMFormValues>({
|
||||
providerName,
|
||||
payload,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
isCustomProvider,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
setIsSubmitting,
|
||||
}: SubmitOnboardingProviderParams): Promise<void> => {
|
||||
setIsSubmitting(true);
|
||||
onSuccess,
|
||||
analyticsSource = LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
}: SubmitProviderParams<T>): Promise<void> {
|
||||
setSubmitting(true);
|
||||
|
||||
// Test credentials
|
||||
let result: TestApiKeyResult;
|
||||
if (isCustomProvider) {
|
||||
result = await testCustomProvider(payload);
|
||||
} else {
|
||||
result = await testApiKeyHelper(providerName, payload);
|
||||
const { test_model_name, api_key, ...rest } = values;
|
||||
const testModelName =
|
||||
test_model_name ||
|
||||
values.model_configurations.find((m) => m.is_visible)?.name ||
|
||||
"";
|
||||
|
||||
// ── Test credentials ────────────────────────────────────────────────
|
||||
const customConfigChanged = !isEqual(
|
||||
values.custom_config,
|
||||
initialValues.custom_config
|
||||
);
|
||||
|
||||
const normalizedApiBase =
|
||||
typeof rest.api_base === "string" && rest.api_base.trim() === ""
|
||||
? undefined
|
||||
: rest.api_base;
|
||||
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: normalizedApiBase,
|
||||
api_key,
|
||||
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
|
||||
custom_config_changed: customConfigChanged,
|
||||
};
|
||||
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setStatus({ isTesting: true });
|
||||
|
||||
const testResult = await submitLlmTestRequest(
|
||||
{
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: testModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
},
|
||||
"An error occurred while testing the provider."
|
||||
);
|
||||
setStatus({ isTesting: false });
|
||||
|
||||
if (!testResult.ok) {
|
||||
toast.error(testResult.errorMessage);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.ok) {
|
||||
toast.error(result.errorMessage);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create provider
|
||||
const response = await fetch(`${LLM_PROVIDERS_ADMIN_URL}?is_creation=true`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
// ── Create/update provider ──────────────────────────────────────────
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
}`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(errorMsg);
|
||||
setIsSubmitting(false);
|
||||
const fullErrorMsg = existingLlmProvider
|
||||
? `Failed to update provider: ${errorMsg}`
|
||||
: `Failed to enable provider: ${errorMsg}`;
|
||||
toast.error(fullErrorMsg);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Set as default if first provider
|
||||
if (
|
||||
onboardingState?.data?.llmProviders == null ||
|
||||
onboardingState.data.llmProviders.length === 0
|
||||
) {
|
||||
// ── Set as default ──────────────────────────────────────────────────
|
||||
if (shouldMarkAsDefault && testModelName) {
|
||||
try {
|
||||
const newLlmProvider = await response.json();
|
||||
if (newLlmProvider?.id != null) {
|
||||
const defaultModelName =
|
||||
(payload as Record<string, string>).default_model_name ??
|
||||
(payload as Record<string, ModelConfiguration[]>)
|
||||
.model_configurations?.[0]?.name ??
|
||||
"";
|
||||
|
||||
if (defaultModelName) {
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: defaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const err = await setDefaultResponse.json().catch(() => ({}));
|
||||
toast.error(err?.detail ?? "Failed to set provider as default");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: testModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const err = await setDefaultResponse.json().catch(() => ({}));
|
||||
toast.error(err?.detail ?? "Failed to set provider as default");
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
} catch (_e) {
|
||||
} catch {
|
||||
toast.error("Failed to set new provider as default");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Post-success ────────────────────────────────────────────────────
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: isCustomProvider ? "custom" : providerName,
|
||||
is_creation: true,
|
||||
source: LLMProviderConfiguredSource.CHAT_ONBOARDING,
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: analyticsSource,
|
||||
});
|
||||
|
||||
// Update onboarding state
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
isCustomProvider ? "custom" : providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
if (onSuccess) await onSuccess();
|
||||
|
||||
setIsSubmitting(false);
|
||||
setSubmitting(false);
|
||||
onClose();
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,197 +1,130 @@
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { ScopedMutator } from "swr";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
|
||||
// Common class names for the Form component across all LLM provider forms
|
||||
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
|
||||
// ─── useInitialValues ─────────────────────────────────────────────────────
|
||||
|
||||
export const buildDefaultInitialValues = (
|
||||
/** Builds the merged model list from existing + well-known, deduped by name. */
|
||||
function buildModelConfigurations(
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
modelConfigurations?: ModelConfiguration[],
|
||||
currentDefaultModelName?: string
|
||||
) => {
|
||||
const defaultModelName =
|
||||
(currentDefaultModelName &&
|
||||
existingLlmProvider?.model_configurations?.some(
|
||||
(m) => m.name === currentDefaultModelName
|
||||
)
|
||||
? currentDefaultModelName
|
||||
: undefined) ??
|
||||
existingLlmProvider?.model_configurations?.[0]?.name ??
|
||||
modelConfigurations?.[0]?.name ??
|
||||
"";
|
||||
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
|
||||
): ModelConfiguration[] {
|
||||
const existingModels = existingLlmProvider?.model_configurations ?? [];
|
||||
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
|
||||
|
||||
// Auto mode must be explicitly enabled by the user
|
||||
// Default to false for new providers, preserve existing value when editing
|
||||
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;
|
||||
const modelMap = new Map<string, ModelConfiguration>();
|
||||
wellKnownModels.forEach((m) => modelMap.set(m.name, m));
|
||||
existingModels.forEach((m) => modelMap.set(m.name, m));
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
}
|
||||
|
||||
/** Shared initial values for all LLM provider forms (both onboarding and admin). */
|
||||
export function useInitialValues(
|
||||
isOnboarding: boolean,
|
||||
providerName: LLMProviderName,
|
||||
existingLlmProvider?: LLMProviderView
|
||||
) {
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(providerName);
|
||||
|
||||
const modelConfigurations = buildModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? undefined
|
||||
);
|
||||
|
||||
const testModelName =
|
||||
modelConfigurations.find((m) => m.is_visible)?.name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name;
|
||||
|
||||
return {
|
||||
name: existingLlmProvider?.name || "",
|
||||
default_model_name: defaultModelName,
|
||||
provider: existingLlmProvider?.provider ?? providerName,
|
||||
name: isOnboarding ? providerName : existingLlmProvider?.name ?? "",
|
||||
api_key: existingLlmProvider?.api_key ?? undefined,
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
is_auto_mode: isAutoMode,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
personas: existingLlmProvider?.personas ?? [],
|
||||
selected_model_names: existingLlmProvider
|
||||
? existingLlmProvider.model_configurations
|
||||
.filter((modelConfiguration) => modelConfiguration.is_visible)
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
: modelConfigurations
|
||||
?.filter((modelConfiguration) => modelConfiguration.is_visible)
|
||||
.map((modelConfiguration) => modelConfiguration.name) ?? [],
|
||||
model_configurations: modelConfigurations,
|
||||
test_model_name: testModelName,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
// ─── buildValidationSchema ────────────────────────────────────────────────
|
||||
|
||||
interface ValidationSchemaOptions {
|
||||
apiKey?: boolean;
|
||||
apiBase?: boolean;
|
||||
extra?: Yup.ObjectShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the validation schema for a modal.
|
||||
*
|
||||
* @param isOnboarding — controls the base schema:
|
||||
* - `true`: minimal (only `test_model_name`).
|
||||
* - `false`: full admin schema (display name, access, models, etc.).
|
||||
* @param options.apiKey — require `api_key`.
|
||||
* @param options.apiBase — require `api_base`.
|
||||
* @param options.extra — arbitrary Yup fields for provider-specific validation.
|
||||
*/
|
||||
export function buildValidationSchema(
|
||||
isOnboarding: boolean,
|
||||
{ apiKey, apiBase, extra }: ValidationSchemaOptions = {}
|
||||
) {
|
||||
const providerFields: Yup.ObjectShape = {
|
||||
...(apiKey && {
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
}),
|
||||
...(apiBase && {
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
}),
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (isOnboarding) {
|
||||
return Yup.object().shape({
|
||||
test_model_name: Yup.string().required("Model name is required"),
|
||||
...providerFields,
|
||||
});
|
||||
}
|
||||
|
||||
export const buildDefaultValidationSchema = () => {
|
||||
return Yup.object({
|
||||
name: Yup.string().required("Display Name is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
is_public: Yup.boolean().required(),
|
||||
is_auto_mode: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
personas: Yup.array().of(Yup.number()),
|
||||
selected_model_names: Yup.array().of(Yup.string()),
|
||||
test_model_name: Yup.string().required("Model name is required"),
|
||||
...providerFields,
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
export const buildAvailableModelConfigurations = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
|
||||
): ModelConfiguration[] => {
|
||||
const existingModels = existingLlmProvider?.model_configurations ?? [];
|
||||
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
|
||||
// ─── Form value types ─────────────────────────────────────────────────────
|
||||
|
||||
// Create a map to deduplicate by model name, preferring existing models
|
||||
const modelMap = new Map<string, ModelConfiguration>();
|
||||
|
||||
// Add well-known models first
|
||||
wellKnownModels.forEach((model) => {
|
||||
modelMap.set(model.name, model);
|
||||
});
|
||||
|
||||
// Override with existing models (they take precedence)
|
||||
existingModels.forEach((model) => {
|
||||
modelMap.set(model.name, model);
|
||||
});
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
};
|
||||
|
||||
// Base form values that all provider forms share
|
||||
/** Base form values that all provider forms share. */
|
||||
export interface BaseLLMFormValues {
|
||||
name: string;
|
||||
api_key?: string;
|
||||
api_base?: string;
|
||||
default_model_name?: string;
|
||||
/** Model name used for the test request — automatically derived. */
|
||||
test_model_name?: string;
|
||||
is_public: boolean;
|
||||
is_auto_mode: boolean;
|
||||
groups: number[];
|
||||
personas: number[];
|
||||
selected_model_names: string[];
|
||||
/** The full model list with is_visible set directly by user interaction. */
|
||||
model_configurations: ModelConfiguration[];
|
||||
custom_config?: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface SubmitLLMProviderParams<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
values: T;
|
||||
initialValues: T;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
hideSuccess?: boolean;
|
||||
setIsTesting: (testing: boolean) => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
|
||||
export const filterModelConfigurations = (
|
||||
currentModelConfigurations: ModelConfiguration[],
|
||||
visibleModels: string[],
|
||||
defaultModelName?: string
|
||||
): ModelConfiguration[] => {
|
||||
return currentModelConfigurations
|
||||
.map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
)
|
||||
.filter(
|
||||
(modelConfiguration) =>
|
||||
modelConfiguration.name === defaultModelName ||
|
||||
modelConfiguration.is_visible
|
||||
);
|
||||
};
|
||||
|
||||
// Helper to get model configurations for auto mode
|
||||
// In auto mode, we include ALL models but preserve their visibility status
|
||||
// Models in the auto config are visible, others are created but not visible
|
||||
export const getAutoModeModelConfigurations = (
|
||||
modelConfigurations: ModelConfiguration[]
|
||||
): ModelConfiguration[] => {
|
||||
return modelConfigurations.map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: modelConfiguration.is_visible,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
);
|
||||
};
|
||||
// ─── Misc ─────────────────────────────────────────────────────────────────
|
||||
|
||||
export type TestApiKeyResult =
|
||||
| { ok: true }
|
||||
| { ok: false; errorMessage: string };
|
||||
|
||||
export const getModelOptions = (
|
||||
fetchedModelConfigurations: Array<{ name: string }>
|
||||
) => {
|
||||
return fetchedModelConfigurations.map((model) => ({
|
||||
label: model.name,
|
||||
value: model.name,
|
||||
}));
|
||||
};
|
||||
|
||||
/** Initial values used by onboarding forms (flat shape, always creating new). */
|
||||
export const buildOnboardingInitialValues = () => ({
|
||||
name: "",
|
||||
provider: "",
|
||||
api_key: "",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
default_model_name: "",
|
||||
model_configurations: [] as ModelConfiguration[],
|
||||
custom_config: {} as Record<string, string>,
|
||||
api_key_changed: true,
|
||||
groups: [] as number[],
|
||||
is_public: true,
|
||||
is_auto_mode: false,
|
||||
personas: [] as number[],
|
||||
selected_model_names: [] as string[],
|
||||
deployment_name: "",
|
||||
target_uri: "",
|
||||
});
|
||||
|
||||
export interface SubmitOnboardingProviderParams {
|
||||
providerName: string;
|
||||
payload: Record<string, unknown>;
|
||||
onboardingState: OnboardingState;
|
||||
onboardingActions: OnboardingActions;
|
||||
isCustomProvider: boolean;
|
||||
onClose: () => void;
|
||||
setIsSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ import React from "react";
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderFormProps,
|
||||
} from "@/interfaces/llm";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
import OpenAIModal from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
@@ -12,8 +13,9 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
// Display info for LLM provider cards - title is the product name, displayName is the company/platform
|
||||
const PROVIDER_DISPLAY_INFO: Record<
|
||||
@@ -47,6 +49,10 @@ const PROVIDER_DISPLAY_INFO: Record<
|
||||
title: "LiteLLM Proxy",
|
||||
displayName: "LiteLLM Proxy",
|
||||
},
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: {
|
||||
title: "OpenAI Compatible",
|
||||
displayName: "OpenAI Compatible",
|
||||
},
|
||||
};
|
||||
|
||||
export function getProviderDisplayInfo(providerName: string): {
|
||||
@@ -78,12 +84,26 @@ export function getOnboardingForm({
|
||||
open,
|
||||
onOpenChange,
|
||||
}: OnboardingFormProps): React.ReactNode {
|
||||
const sharedProps = {
|
||||
const providerName = isCustomProvider
|
||||
? "custom"
|
||||
: llmDescriptor?.name ?? "custom";
|
||||
|
||||
const sharedProps: LLMProviderFormProps = {
|
||||
variant: "onboarding" as const,
|
||||
onboardingState,
|
||||
shouldMarkAsDefault:
|
||||
(onboardingState?.data.llmProviders ?? []).length === 0,
|
||||
onboardingActions,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess: () => {
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
},
|
||||
};
|
||||
|
||||
// Handle custom provider
|
||||
@@ -91,38 +111,36 @@ export function getOnboardingForm({
|
||||
return <CustomModal {...sharedProps} />;
|
||||
}
|
||||
|
||||
const providerProps = {
|
||||
...sharedProps,
|
||||
llmDescriptor,
|
||||
};
|
||||
|
||||
switch (llmDescriptor.name) {
|
||||
case LLMProviderName.OPENAI:
|
||||
return <OpenAIModal {...providerProps} />;
|
||||
return <OpenAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return <AnthropicModal {...providerProps} />;
|
||||
return <AnthropicModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...providerProps} />;
|
||||
return <OllamaModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureModal {...providerProps} />;
|
||||
return <AzureModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...providerProps} />;
|
||||
return <BedrockModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...providerProps} />;
|
||||
return <VertexAIModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...providerProps} />;
|
||||
return <OpenRouterModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioForm {...providerProps} />;
|
||||
return <LMStudioModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...providerProps} />;
|
||||
return <LiteLLMProxyModal {...sharedProps} />;
|
||||
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...sharedProps} />;
|
||||
|
||||
default:
|
||||
return <CustomModal {...sharedProps} />;
|
||||
|
||||
Reference in New Issue
Block a user