Compare commits

..

3 Commits

Author SHA1 Message Date
Jamison Lahman
77ce667b21 nit 2026-04-05 18:45:22 -07:00
Jamison Lahman
1e0a8afc72 INTERNAL_URL 2026-04-05 18:05:02 -07:00
Jamison Lahman
85302a1cde feat(cli): --config-file and --server-url 2026-04-05 17:32:48 -07:00
106 changed files with 4173 additions and 4555 deletions

View File

@@ -228,7 +228,7 @@ jobs:
- name: Create GitHub Release
id: create-release
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # ratchet:softprops/action-gh-release@v2
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
with:
tag_name: ${{ steps.release-tag.outputs.tag }}
name: ${{ steps.release-tag.outputs.tag }}

View File

@@ -21,7 +21,7 @@ jobs:
persist-credentials: false
- name: Install Helm CLI
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
with:
version: v3.12.1

View File

@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # ratchet:actions/stale@v10
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

View File

@@ -36,7 +36,7 @@ jobs:
persist-credentials: false
- name: Set up Helm
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
with:
version: v3.19.0

3
.gitignore vendored
View File

@@ -59,6 +59,3 @@ node_modules
# plans
plans/
# Added context for LLMs
onyx-llm-context/

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
@@ -19,6 +19,7 @@ from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
@@ -44,6 +45,8 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
@@ -53,6 +56,25 @@ if USE_IAM_AUTH:
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
def include_object(
object: SchemaItem, # noqa: ARG001
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool, # noqa: ARG001
compare_to: SchemaItem | None, # noqa: ARG001
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def filter_tenants_by_range(
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
) -> list[str]:
@@ -209,6 +231,7 @@ def do_run_migrations(
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,
@@ -382,6 +405,7 @@ def run_migrations_offline() -> None:
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
@@ -423,6 +447,7 @@ def run_migrations_offline() -> None:
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
@@ -465,6 +490,7 @@ def run_migrations_online() -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
compare_type=True,

View File

@@ -1,9 +1,11 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem
from alembic import context
from onyx.db.engine.sql_engine import build_connection_string
@@ -33,6 +35,27 @@ target_metadata = [PublicBase.metadata]
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem, # noqa: ARG001
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool, # noqa: ARG001
compare_to: SchemaItem | None, # noqa: ARG001
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
@@ -62,6 +85,7 @@ def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore[arg-type]
include_object=include_object,
)
with context.begin_transaction():

View File

@@ -5,7 +5,6 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -31,7 +30,6 @@ def cloud_beat_task_generator(
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
skip_gated: bool = True,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
@@ -50,22 +48,20 @@ def cloud_beat_task_generator(
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
num_skipped_gated = 0
try:
tenant_ids = get_all_tenant_ids()
# Per-task control over whether gated tenants are included. Most periodic tasks
# do no useful work on gated tenants and just waste DB connections fanning out
# to ~10k+ inactive tenants. A small number of cleanup tasks (connector deletion,
# checkpoint/index attempt cleanup) need to run on gated tenants and pass
# `skip_gated=False` from the beat schedule.
gated_tenants: set[str] = get_gated_tenants() if skip_gated else set()
# NOTE: for now, we are running tasks for gated tenants, since we want to allow
# connector deletion to run successfully. The new plan is to continuously prune
# the gated tenants set, so we won't have a build up of old, unused gated tenants.
# Keeping this around in case we want to revert to the previous behavior.
# gated_tenants = get_gated_tenants()
for tenant_id in tenant_ids:
if tenant_id in gated_tenants:
num_skipped_gated += 1
continue
# Same comment here as the above NOTE
# if tenant_id in gated_tenants:
# continue
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
@@ -108,7 +104,6 @@ def cloud_beat_task_generator(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_skipped_gated={num_skipped_gated} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -1,7 +1,6 @@
# Overview of Onyx Background Jobs
The background jobs take care of:
1. Pulling/Indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and logic around indexing work (indexing checkpoints and index attempt metadata)
@@ -10,41 +9,37 @@ The background jobs take care of:
## Worker → Queue Mapping
| Worker | File | Queues |
| ------------------------- | ------------------------------ | -------------------------------------------------------------------------------------------------------------------- |
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
## Non-Worker Apps
| App | File | Purpose |
| ---------- | ----------- | ----------------------------------------------------------------------------------------------------- |
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
### Shared Module
`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for logging, cleanup, and lifecycle events
- Readiness probes and health checks
## Worker Details
### Primary (Coordinator and task dispatcher)
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
On startup:
- waits for redis, postgres, document index to all be healthy
- acquires the singleton lock
- cleans all the redis states associated with background jobs
@@ -52,34 +47,34 @@ On startup:
Then it cycles through its tasks as scheduled by Celery Beat:
| Task | Frequency | Description |
| --------------------------------- | --------- | ------------------------------------------------------------------------------------------ |
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
See supervisord.conf for watchdog config.
### Light
### Light
Fast and short living tasks that are not resource intensive. High concurrency:
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
Tasks it handles:
- Syncs access/permissions, document sets, boosts, hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleanup of checkpoints and index attempts
### Heavy
### Heavy
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
@@ -88,24 +83,16 @@ Generates CSV exports which may take a long time with significant data in Postgr
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
### Docprocessing, Docfetching, User File Processing
Docprocessing and Docfetching are for indexing documents:
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
- User Files come from uploads directly via the input bar
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
User Files come from uploads directly via the input bar
### Monitoring
Observability and metrics collections:
- Queue lengths, connector success/failure, connector latencies
- Queue lengths, connector success/failure, connector latencies
- Memory of supervisor managed processes (workers, beat, slack)
- Cloud and multitenant specific monitorings
## Prometheus Metrics
Workers can expose Prometheus metrics via a standalone HTTP server. Currently docfetching and docprocessing have push-based task lifecycle metrics; the monitoring worker runs pull-based collectors for queue depth and connector health.
For the full metric reference, integration guide, and PromQL examples, see [`docs/METRICS.md`](../../../docs/METRICS.md#celery-worker-metrics).

View File

@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.evals",
"onyx.background.celery.tasks.hierarchyfetching",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",

View File

@@ -75,8 +75,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
# Run on gated tenants too — they may still have stale checkpoints to clean.
"skip_gated": False,
},
},
{
@@ -86,8 +84,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
# Run on gated tenants too — they may still have stale index attempts.
"skip_gated": False,
},
},
{
@@ -97,8 +93,6 @@ beat_task_templates: list[dict] = [
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
# Gated tenants may still have connectors awaiting deletion.
"skip_gated": False,
},
},
{
@@ -272,7 +266,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
cloud_task["kwargs"] = {}
cloud_task["kwargs"]["task_name"] = task["task"]
optional_fields = ["queue", "priority", "expires", "skip_gated"]
optional_fields = ["queue", "priority", "expires"]
for field in optional_fields:
if field in task["options"]:
cloud_task["kwargs"][field] = task["options"][field]
@@ -365,13 +359,7 @@ if not MULTI_TENANT:
]
)
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
# it before extending the self-hosted schedule so it doesn't leak into apply_async
# as an unrecognised option on every fired task message.
for _template in beat_task_templates:
_self_hosted_template = copy.deepcopy(_template)
_self_hosted_template["options"].pop("skip_gated", None)
tasks_to_schedule.append(_self_hosted_template)
tasks_to_schedule.extend(beat_task_templates)
def generate_cloud_tasks(

View File

@@ -0,0 +1,138 @@
#####
# Periodic Tasks
#####
import json
from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@shared_task(
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
ctx = {}
ctx["last_processed_id"] = 0
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_current_tenant() as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
).scalar()
if not result:
return 0
while True:
if self.is_aborted():
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
b = kombu_message_cleanup_task_helper(ctx, db_session)
if not b:
break
db_session.commit()
if ctx["deleted"] > 0:
task_logger.info(
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
)
return ctx["deleted"]
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
"""
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
This function retrieves messages from the `kombu_message` table that are no longer visible and
older than a specified interval. It checks if the corresponding task_id exists in the
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
Args:
ctx (dict): A context dictionary containing configuration parameters such as:
- 'cleanup_age' (int): The age in days after which messages are considered old.
- 'page_limit' (int): The maximum number of messages to process in one batch.
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
- 'deleted' (int): A counter to track the number of deleted messages.
db_session (Session): The SQLAlchemy database session for executing queries.
Returns:
bool: Returns True if there are more rows to process, False if not.
"""
inspector = inspect(db_session.bind)
if not inspector:
return False
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
# We can fail silently.
if not inspector.has_table("kombu_message"):
return False
query = text(
"""
SELECT id, timestamp, payload
FROM kombu_message WHERE visible = 'false'
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
AND id > :last_processed_id
ORDER BY id
LIMIT :page_limit
"""
)
kombu_messages = db_session.execute(
query,
{
"interval_days": f"{ctx['cleanup_age']} days",
"page_limit": ctx["page_limit"],
"last_processed_id": ctx["last_processed_id"],
},
).fetchall()
if len(kombu_messages) == 0:
return False
for msg in kombu_messages:
payload = json.loads(msg[2])
task_id = payload["headers"]["id"]
# Check if task_id exists in celery_taskmeta
task_exists = db_session.execute(
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
{"task_id": task_id},
).fetchone()
# If task_id does not exist, delete the message
if not task_exists:
result = db_session.execute(
text("DELETE FROM kombu_message WHERE id = :message_id"),
{"message_id": msg[0]},
)
if result.rowcount > 0: # type: ignore
ctx["deleted"] += 1
ctx["last_processed_id"] = msg[0]
return True

View File

@@ -379,14 +379,6 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
# Comma-separated replica / multi-host list. If unset, defaults to POSTGRES_HOST
# only.
_POSTGRES_HOSTS_STR = os.environ.get("POSTGRES_HOSTS", "").strip()
POSTGRES_HOSTS: list[str] = (
[h.strip() for h in _POSTGRES_HOSTS_STR.split(",") if h.strip()]
if _POSTGRES_HOSTS_STR
else [POSTGRES_HOST]
)
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40

View File

@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
MASK_CREDENTIAL_CHAR = "\u2022"
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -396,6 +391,10 @@ class MilestoneRecordType(str, Enum):
REQUESTED_CONNECTOR = "requested_connector"
class PostgresAdvisoryLocks(Enum):
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
class OnyxCeleryQueues:
# "celery" is the default queue defined by celery and also the queue
# we are running in the primary worker to run system tasks
@@ -578,6 +577,7 @@ class OnyxCeleryTask:
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
)

View File

@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.constants import FederatedConnectorSource
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
return fetch_all_federated_connectors(db_session)
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
"""Raise if any credential string value contains mask placeholder characters.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
Both must be rejected.
"""
for key, val in credentials.items():
if isinstance(val, str) and (
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
):
raise ValueError(
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
)
def validate_federated_connector_credentials(
source: FederatedConnectorSource,
credentials: dict[str, Any],
@@ -85,8 +66,6 @@ def create_federated_connector(
config: dict[str, Any] | None = None,
) -> FederatedConnector:
"""Create a new federated connector with credential and config validation."""
_reject_masked_credentials(credentials)
# Validate credentials before creating
if not validate_federated_connector_credentials(source, credentials):
raise ValueError(
@@ -298,8 +277,6 @@ def update_federated_connector(
)
if credentials is not None:
_reject_masked_credentials(credentials)
# Validate credentials before updating
if not validate_federated_connector_credentials(
federated_connector.source, credentials

View File

@@ -236,15 +236,14 @@ def upsert_llm_provider(
db_session.add(existing_llm_provider)
# Filter out empty strings and None values from custom_config to allow
# providers like Bedrock to fall back to IAM roles when credentials are not provided.
# NOTE: An empty dict ({}) is preserved as-is — it signals that the provider was
# created via the custom modal and must be reopened with CustomModal, not a
# provider-specific modal. Only None means "no custom config at all".
# providers like Bedrock to fall back to IAM roles when credentials are not provided
custom_config = llm_provider_upsert_request.custom_config
if custom_config:
custom_config = {
k: v for k, v in custom_config.items() if v is not None and v.strip() != ""
}
# Set to None if the dict is empty after filtering
custom_config = custom_config or None
api_base = llm_provider_upsert_request.api_base or None
existing_llm_provider.provider = llm_provider_upsert_request.provider
@@ -304,7 +303,16 @@ def upsert_llm_provider(
).delete(synchronize_session="fetch")
db_session.flush()
# Import here to avoid circular imports
from onyx.llm.utils import get_max_input_tokens
for model_config in llm_provider_upsert_request.model_configurations:
max_input_tokens = model_config.max_input_tokens
if max_input_tokens is None:
max_input_tokens = get_max_input_tokens(
model_name=model_config.name,
model_provider=llm_provider_upsert_request.provider,
)
supported_flows = [LLMModelFlowType.CHAT]
if model_config.supports_image_input:
@@ -317,7 +325,7 @@ def upsert_llm_provider(
model_configuration_id=existing.id,
supported_flows=supported_flows,
is_visible=model_config.is_visible,
max_input_tokens=model_config.max_input_tokens,
max_input_tokens=max_input_tokens,
display_name=model_config.display_name,
)
else:
@@ -327,7 +335,7 @@ def upsert_llm_provider(
model_name=model_config.name,
supported_flows=supported_flows,
is_visible=model_config.is_visible,
max_input_tokens=model_config.max_input_tokens,
max_input_tokens=max_input_tokens,
display_name=model_config.display_name,
)

View File

@@ -52,21 +52,9 @@ KNOWN_OPENPYXL_BUGS = [
def get_markitdown_converter() -> "MarkItDown":
global _MARKITDOWN_CONVERTER
from markitdown import MarkItDown
if _MARKITDOWN_CONVERTER is None:
from markitdown import MarkItDown
# Patch this function to effectively no-op because we were seeing this
# module take an inordinate amount of time to convert charts to markdown,
# making some powerpoint files with many or complicated charts nearly
# unindexable.
from markitdown.converters._pptx_converter import PptxConverter
setattr(
PptxConverter,
"_convert_chart_to_markdown",
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
)
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
return _MARKITDOWN_CONVERTER
@@ -217,26 +205,18 @@ def read_pdf_file(
try:
pdf_reader = PdfReader(file)
if pdf_reader.is_encrypted:
# Try the explicit password first, then fall back to an empty
# string. Owner-password-only PDFs (permission restrictions but
# no open password) decrypt successfully with "".
# See https://github.com/onyx-dot-app/onyx/issues/9754
passwords = [p for p in [pdf_pass, ""] if p is not None]
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
for pw in passwords:
try:
if pdf_reader.decrypt(pw) != 0:
decrypt_success = True
break
except Exception:
pass
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error("Unable to decrypt pdf")
if not decrypt_success:
logger.error(
"Encrypted PDF could not be decrypted, returning empty text."
)
return "", metadata, []
elif pdf_reader.is_encrypted:
logger.warning("No Password for an encrypted PDF, returning empty text.")
return "", metadata, []
# Basic PDF metadata
if pdf_reader.metadata is not None:

View File

@@ -33,20 +33,8 @@ def is_pdf_protected(file: IO[Any]) -> bool:
with preserve_position(file):
reader = PdfReader(file)
if not reader.is_encrypted:
return False
# PDFs with only an owner password (permission restrictions like
# print/copy disabled) use an empty user password — any viewer can open
# them without prompting. decrypt("") returns 0 only when a real user
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
try:
return reader.decrypt("") == 0
except Exception:
logger.exception(
"Failed to evaluate PDF encryption; treating as password protected"
)
return True
return bool(reader.is_encrypted)
def is_docx_protected(file: IO[Any]) -> bool:

View File

@@ -26,7 +26,6 @@ class LlmProviderNames(str, Enum):
MISTRAL = "mistral"
LITELLM_PROXY = "litellm_proxy"
BIFROST = "bifrost"
OPENAI_COMPATIBLE = "openai_compatible"
def __str__(self) -> str:
"""Needed so things like:
@@ -47,7 +46,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
]
@@ -66,7 +64,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
LlmProviderNames.BIFROST: "Bifrost",
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -119,7 +116,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
# Model family name mappings for display name generation

View File

@@ -327,19 +327,12 @@ class LitellmLLM(LLM):
):
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
# model names directly to the endpoint. We route through LiteLLM's
# openai provider with the server's base URL, and ensure /v1 is appended.
if model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
):
# Bifrost: OpenAI-compatible proxy that expects model names in
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
# We route through LiteLLM's openai provider with the Bifrost base URL,
# and ensure /v1 is appended.
if model_provider == LlmProviderNames.BIFROST:
self._custom_llm_provider = "openai"
# LiteLLM's OpenAI client requires an api_key to be set.
# Many OpenAI-compatible servers don't need auth, so supply a
# placeholder to prevent LiteLLM from raising AuthenticationError.
if not self._api_key:
model_kwargs.setdefault("api_key", "not-needed")
if self._api_base is not None:
base = self._api_base.rstrip("/")
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
@@ -456,20 +449,17 @@ class LitellmLLM(LLM):
optional_kwargs: dict[str, Any] = {}
# Model name
is_openai_compatible_proxy = self._model_provider in (
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
)
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
model_provider = (
f"{self.config.model_provider}/responses"
if is_openai_model # Uses litellm's completions -> responses bridge
else self.config.model_provider
)
if is_openai_compatible_proxy:
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
# servers) expect model names sent directly to their endpoint.
# We use custom_llm_provider="openai" so LiteLLM doesn't try
# to route based on the provider prefix.
if is_bifrost:
# Bifrost expects model names in provider/model format
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
# so LiteLLM doesn't try to route based on the provider prefix.
model = self.config.deployment_name or self.config.model_name
else:
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
@@ -560,10 +550,7 @@ class LitellmLLM(LLM):
if structured_response_format:
optional_kwargs["response_format"] = structured_response_format
if (
not (is_claude_model or is_ollama or is_mistral)
or is_openai_compatible_proxy
):
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
# However, this param breaks Anthropic and Mistral models,
# so it must be conditionally included unless the request is

View File

@@ -15,8 +15,6 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
BIFROST_PROVIDER_NAME = "bifrost"
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

View File

@@ -19,7 +19,6 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
@@ -52,7 +51,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
}
@@ -338,7 +336,6 @@ def get_provider_display_name(provider_name: str) -> str:
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

View File

@@ -6,7 +6,6 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
from onyx.configs.app_configs import MCP_SERVER_HOST
from onyx.configs.app_configs import MCP_SERVER_PORT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
logger = setup_logger()
@@ -17,7 +16,6 @@ def main() -> None:
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
return
set_is_ee_based_on_env_variable()
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
from onyx.mcp_server.api import mcp_app

View File

@@ -74,8 +74,6 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
@@ -1577,95 +1575,3 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
source_name="Bifrost",
api_key=api_key,
)
@admin_router.post("/openai-compatible/available-models")
def get_openai_compatible_server_available_models(
request: OpenAICompatibleModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[OpenAICompatibleFinalModelResponse]:
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
response_json = _get_openai_compatible_server_response(
api_base=request.api_base, api_key=request.api_key
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your OpenAI-compatible endpoint",
)
results: list[OpenAICompatibleFinalModelResponse] = []
for model in models:
try:
model_id = model.get("id", "")
model_name = model.get("name", model_id)
if not model_id:
continue
# Skip embedding models
if is_embedding_model(model_id):
continue
results.append(
OpenAICompatibleFinalModelResponse(
name=model_id,
display_name=model_name,
max_input_tokens=model.get("context_length"),
supports_image_input=infer_vision_support(model_id),
supports_reasoning=is_reasoning_model(model_id, model_name),
)
)
except Exception as e:
logger.warning(
"Failed to parse OpenAI-compatible model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from OpenAI-compatible endpoint",
)
sorted_results = sorted(results, key=lambda m: m.name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
for r in sorted_results
],
source_label="OpenAI Compatible",
)
return sorted_results
def _get_openai_compatible_server_response(
api_base: str, api_key: str | None = None
) -> dict:
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
# Ensure we hit /v1/models
if cleaned_api_base.endswith("/v1"):
url = f"{cleaned_api_base}/models"
else:
url = f"{cleaned_api_base}/v1/models"
return _get_openai_compatible_models_response(
url=url,
source_name="OpenAI Compatible",
api_key=api_key,
)

View File

@@ -79,9 +79,7 @@ class LLMProviderDescriptor(BaseModel):
provider=provider,
provider_display_name=get_provider_display_name(provider),
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations,
provider,
use_stored_display_name=llm_provider_model.custom_config is not None,
llm_provider_model.model_configurations, provider
),
)
@@ -158,9 +156,7 @@ class LLMProviderView(LLMProvider):
personas=personas,
deployment_name=llm_provider_model.deployment_name,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations,
provider,
use_stored_display_name=llm_provider_model.custom_config is not None,
llm_provider_model.model_configurations, provider
),
)
@@ -202,13 +198,13 @@ class ModelConfigurationView(BaseModel):
cls,
model_configuration_model: "ModelConfigurationModel",
provider_name: str,
use_stored_display_name: bool = False,
) -> "ModelConfigurationView":
# For dynamic providers (OpenRouter, Bedrock, Ollama) and custom-config
# providers, use the display_name stored in DB. Skip LiteLLM parsing.
# For dynamic providers (OpenRouter, Bedrock, Ollama), use the display_name
# stored in DB from the source API. Skip LiteLLM parsing entirely.
if (
provider_name in DYNAMIC_LLM_PROVIDERS or use_stored_display_name
) and model_configuration_model.display_name:
provider_name in DYNAMIC_LLM_PROVIDERS
and model_configuration_model.display_name
):
# Extract vendor from model name for grouping (e.g., "Anthropic", "OpenAI")
vendor = extract_vendor_from_model_name(
model_configuration_model.name, provider_name
@@ -468,18 +464,3 @@ class BifrostFinalModelResponse(BaseModel):
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool
# OpenAI Compatible dynamic models fetch
class OpenAICompatibleModelsRequest(BaseModel):
api_base: str
api_key: str | None = None
provider_name: str | None = None # Optional: to save models to existing provider
class OpenAICompatibleFinalModelResponse(BaseModel):
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
display_name: str # Human-readable name from API
max_input_tokens: int | None
supports_image_input: bool
supports_reasoning: bool

View File

@@ -26,7 +26,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.BIFROST,
LlmProviderNames.OPENAI_COMPATIBLE,
}
)
@@ -309,15 +308,12 @@ def should_filter_as_dated_duplicate(
def filter_model_configurations(
model_configurations: list,
provider: str,
use_stored_display_name: bool = False,
) -> list:
"""Filter out obsolete and dated duplicate models from configurations.
Args:
model_configurations: List of ModelConfiguration DB models
provider: The provider name (e.g., "openai", "anthropic")
use_stored_display_name: If True, prefer the display_name stored in the
DB over LiteLLM enrichments. Set for custom-config providers.
Returns:
List of ModelConfigurationView objects with obsolete/duplicate models removed
@@ -337,9 +333,7 @@ def filter_model_configurations(
if should_filter_as_dated_duplicate(model_configuration.name, all_model_names):
continue
filtered_configs.append(
ModelConfigurationView.from_model(
model_configuration, provider, use_stored_display_name
)
ModelConfigurationView.from_model(model_configuration, provider)
)
return filtered_configs

View File

@@ -186,7 +186,7 @@ class TestDocumentIndexNew:
)
document_index.index(chunks=[pre_chunk], indexing_metadata=pre_metadata)
time.sleep(2)
time.sleep(1)
# Now index a batch with the existing doc and a new doc.
chunks = [

View File

@@ -1,58 +0,0 @@
import pytest
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
from onyx.db.federated import _reject_masked_credentials
class TestRejectMaskedCredentials:
"""Verify that masked credential values are never accepted for DB writes.
mask_string() has two output formats:
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
_reject_masked_credentials must catch both.
"""
def test_rejects_fully_masked_value(self) -> None:
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": masked})
def test_rejects_long_string_masked_value(self) -> None:
"""mask_string returns 'first4...last4' for long strings — the real
format used for OAuth credentials like client_id and client_secret."""
with pytest.raises(ValueError, match="masked placeholder"):
_reject_masked_credentials({"client_id": "1234...7890"})
def test_rejects_when_any_field_is_masked(self) -> None:
"""Even if client_id is real, a masked client_secret must be caught."""
with pytest.raises(ValueError, match="client_secret"):
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": MASK_CREDENTIAL_CHAR * 12,
}
)
def test_accepts_real_credentials(self) -> None:
# Should not raise
_reject_masked_credentials(
{
"client_id": "1234567890.1234567890",
"client_secret": "test_client_secret_value",
}
)
def test_accepts_empty_dict(self) -> None:
# Should not raise — empty credentials are handled elsewhere
_reject_masked_credentials({})
def test_ignores_non_string_values(self) -> None:
# Non-string values (None, bool, int) should pass through
_reject_masked_credentials(
{
"client_id": "real_value",
"redirect_uri": None,
"some_flag": True,
}
)

View File

@@ -1,76 +0,0 @@
%PDF-1.3
%<25><><EFBFBD><EFBFBD>
1 0 obj
<<
/Producer <1083d595b1>
>>
endobj
2 0 obj
<<
/Type /Pages
/Count 1
/Kids [ 4 0 R ]
>>
endobj
3 0 obj
<<
/Type /Catalog
/Pages 2 0 R
>>
endobj
4 0 obj
<<
/Type /Page
/Resources <<
/Font <<
/F1 <<
/Type /Font
/Subtype /Type1
/BaseFont /Helvetica
>>
>>
>>
/MediaBox [ 0.0 0.0 200 200 ]
/Contents 5 0 R
/Parent 2 0 R
>>
endobj
5 0 obj
<<
/Length 42
>>
stream
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD> <0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
endstream
endobj
6 0 obj
<<
/V 2
/R 3
/Length 128
/P 4294967292
/Filter /Standard
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
>>
endobj
xref
0 7
0000000000 65535 f
0000000015 00000 n
0000000059 00000 n
0000000118 00000 n
0000000167 00000 n
0000000348 00000 n
0000000440 00000 n
trailer
<<
/Size 7
/Root 3 0 R
/Info 1 0 R
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
/Encrypt 6 0 R
>>
startxref
655
%%EOF

View File

@@ -54,12 +54,6 @@ class TestReadPdfFile:
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
assert text == ""
def test_owner_password_only_pdf_extracts_text(self) -> None:
"""A PDF encrypted with only an owner password (no user password)
should still yield its text content. Regression for #9754."""
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
assert "Hello World" in text
def test_empty_pdf(self) -> None:
text, _, _ = read_pdf_file(_load("empty.pdf"))
assert text.strip() == ""
@@ -123,12 +117,6 @@ class TestIsPdfProtected:
def test_protected_pdf(self) -> None:
assert is_pdf_protected(_load("encrypted.pdf")) is True
def test_owner_password_only_is_not_protected(self) -> None:
"""A PDF with only an owner password (permission restrictions) but no
user password should NOT be considered protected — any viewer can open
it without prompting for a password."""
assert is_pdf_protected(_load("owner_protected.pdf")) is False
def test_preserves_file_position(self) -> None:
pdf = _load("simple.pdf")
pdf.seek(42)

View File

@@ -1,79 +0,0 @@
import io
from pptx import Presentation # type: ignore[import-untyped]
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
from pptx.util import Inches # type: ignore[import-untyped]
from onyx.file_processing.extract_file_text import pptx_to_text
def _make_pptx_with_chart() -> io.BytesIO:
"""Create an in-memory pptx with one text slide and one chart slide."""
prs = Presentation()
# Slide 1: text only
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
slide1.shapes.title.text = "Introduction"
slide1.placeholders[1].text = "This is the first slide."
# Slide 2: chart
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
chart_data = CategoryChartData()
chart_data.categories = ["Q1", "Q2", "Q3"]
chart_data.add_series("Revenue", (100, 200, 300))
slide2.shapes.add_chart(
XL_CHART_TYPE.COLUMN_CLUSTERED,
Inches(1),
Inches(1),
Inches(6),
Inches(4),
chart_data,
)
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
def _make_pptx_without_chart() -> io.BytesIO:
"""Create an in-memory pptx with a single text-only slide."""
prs = Presentation()
slide = prs.slides.add_slide(prs.slide_layouts[1])
slide.shapes.title.text = "Hello World"
slide.placeholders[1].text = "Some content here."
buf = io.BytesIO()
prs.save(buf)
buf.seek(0)
return buf
class TestPptxToText:
def test_chart_is_omitted(self) -> None:
# Precondition
pptx_file = _make_pptx_with_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Introduction" in result
assert "first slide" in result
assert "[chart omitted]" in result
# The actual chart data should NOT appear in the output.
assert "Revenue" not in result
assert "Q1" not in result
def test_text_only_pptx(self) -> None:
# Precondition
pptx_file = _make_pptx_without_chart()
# Under test
result = pptx_to_text(pptx_file)
# Postcondition
assert "Hello World" in result
assert "Some content" in result
assert "[chart omitted]" not in result

View File

@@ -6,7 +6,6 @@ import (
"text/tabwriter"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/spf13/cobra"
)
@@ -25,7 +24,7 @@ Use --json for machine-readable output.`,
onyx-cli agents --json
onyx-cli agents --json | jq '.[].name'`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
cfg := loadConfig(cmd)
if !cfg.IsConfigured() {
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
}

View File

@@ -11,7 +11,6 @@ import (
"syscall"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/onyx-dot-app/onyx/cli/internal/overflow"
@@ -49,7 +48,7 @@ to a temp file. Set --max-output 0 to disable truncation.`,
cat error.log | onyx-cli ask --prompt "Find the root cause"
echo "what is onyx?" | onyx-cli ask`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
cfg := loadConfig(cmd)
if !cfg.IsConfigured() {
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
}

View File

@@ -21,7 +21,7 @@ an interactive setup wizard will guide you through configuration.`,
Example: ` onyx-cli chat
onyx-cli`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
cfg := loadConfig(cmd)
// First-run: onboarding
if !config.ConfigExists() || !cfg.IsConfigured() {

View File

@@ -69,7 +69,7 @@ Use --dry-run to test the connection without saving the configuration.`,
return exitcodes.New(exitcodes.BadRequest, "both --server-url and --api-key are required for non-interactive setup\n Run 'onyx-cli configure' without flags for interactive setup")
}
cfg := config.Load()
cfg := loadConfig(cmd)
onboarding.Run(&cfg)
return nil
},

View File

@@ -12,7 +12,7 @@ func newExperimentsCmd() *cobra.Command {
Use: "experiments",
Short: "List experimental features and their status",
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
cfg := loadConfig(cmd)
_, _ = fmt.Fprintln(cmd.OutOrStdout(), config.ExperimentsText(cfg.Features))
return nil
},

View File

@@ -13,6 +13,20 @@ import (
"github.com/spf13/cobra"
)
// loadConfig loads the CLI config, using the --config-file persistent flag if set.
func loadConfig(cmd *cobra.Command) config.OnyxCliConfig {
cf, _ := cmd.Flags().GetString("config-file")
return config.Load(cf)
}
// effectiveConfigPath returns the config file path, respecting --config-file.
func effectiveConfigPath(cmd *cobra.Command) string {
if cf, _ := cmd.Flags().GetString("config-file"); cf != "" {
return cf
}
return config.ConfigFilePath()
}
// Version and Commit are set via ldflags at build time.
var (
Version string
@@ -29,7 +43,7 @@ func fullVersion() string {
func printVersion(cmd *cobra.Command) {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Client version: %s\n", fullVersion())
cfg := config.Load()
cfg := loadConfig(cmd)
if !cfg.IsConfigured() {
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (not configured)\n")
return
@@ -84,6 +98,8 @@ func Execute() error {
}
rootCmd.PersistentFlags().BoolVar(&opts.Debug, "debug", false, "run in debug mode")
rootCmd.PersistentFlags().String("config-file", "",
"Path to config file (default: "+config.ConfigFilePath()+")")
// Custom --version flag instead of Cobra's built-in (which only shows one version string)
var showVersion bool

View File

@@ -50,16 +50,14 @@ func sessionEnv(s ssh.Session, key string) string {
return ""
}
func validateAPIKey(serverURL string, apiKey string) error {
func validateAPIKey(serverCfg config.OnyxCliConfig, apiKey string) error {
trimmedKey := strings.TrimSpace(apiKey)
if len(trimmedKey) > maxAPIKeyLength {
return fmt.Errorf("API key is too long (max %d characters)", maxAPIKeyLength)
}
cfg := config.OnyxCliConfig{
ServerURL: serverURL,
APIKey: trimmedKey,
}
cfg := serverCfg
cfg.APIKey = trimmedKey
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(context.Background(), apiKeyValidationTimeout)
defer cancel()
@@ -83,7 +81,7 @@ type authValidatedMsg struct {
type authModel struct {
input textinput.Model
serverURL string
serverCfg config.OnyxCliConfig
state authState
apiKey string // set on successful validation
errMsg string
@@ -91,7 +89,7 @@ type authModel struct {
aborted bool
}
func newAuthModel(serverURL, initialErr string) authModel {
func newAuthModel(serverCfg config.OnyxCliConfig, initialErr string) authModel {
ti := textinput.New()
ti.Prompt = " API Key: "
ti.EchoMode = textinput.EchoPassword
@@ -102,7 +100,7 @@ func newAuthModel(serverURL, initialErr string) authModel {
return authModel{
input: ti,
serverURL: serverURL,
serverCfg: serverCfg,
errMsg: initialErr,
}
}
@@ -138,9 +136,9 @@ func (m authModel) Update(msg tea.Msg) (authModel, tea.Cmd) {
}
m.state = authValidating
m.errMsg = ""
serverURL := m.serverURL
serverCfg := m.serverCfg
return m, func() tea.Msg {
return authValidatedMsg{key: key, err: validateAPIKey(serverURL, key)}
return authValidatedMsg{key: key, err: validateAPIKey(serverCfg, key)}
}
}
@@ -171,12 +169,13 @@ func (m authModel) Update(msg tea.Msg) (authModel, tea.Cmd) {
}
func (m authModel) View() string {
settingsURL := strings.TrimRight(m.serverURL, "/") + "/app/settings/accounts-access"
serverURL := m.serverCfg.ServerURL
settingsURL := strings.TrimRight(serverURL, "/") + "/app/settings/accounts-access"
var b strings.Builder
b.WriteString("\n")
b.WriteString(" \x1b[1;35mOnyx CLI\x1b[0m\n")
b.WriteString(" \x1b[90m" + m.serverURL + "\x1b[0m\n")
b.WriteString(" \x1b[90m" + serverURL + "\x1b[0m\n")
b.WriteString("\n")
b.WriteString(" Generate an API key at:\n")
b.WriteString(" \x1b[4;34m" + settingsURL + "\x1b[0m\n")
@@ -215,7 +214,7 @@ type serveModel struct {
func newServeModel(serverCfg config.OnyxCliConfig, initialErr string) serveModel {
return serveModel{
auth: newAuthModel(serverCfg.ServerURL, initialErr),
auth: newAuthModel(serverCfg, initialErr),
serverCfg: serverCfg,
}
}
@@ -238,11 +237,8 @@ func (m serveModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Quit
}
if m.auth.apiKey != "" {
cfg := config.OnyxCliConfig{
ServerURL: m.serverCfg.ServerURL,
APIKey: m.auth.apiKey,
DefaultAgentID: m.serverCfg.DefaultAgentID,
}
cfg := m.serverCfg
cfg.APIKey = m.auth.apiKey
m.tui = tui.NewModel(cfg)
m.authed = true
w, h := m.width, m.height
@@ -280,6 +276,8 @@ func newServeCmd() *cobra.Command {
rateLimitPerMin int
rateLimitBurst int
rateLimitCache int
serverURL string
apiServerURL string
)
cmd := &cobra.Command{
@@ -300,9 +298,18 @@ environment variable (the --host-key flag takes precedence).`,
Example: ` onyx-cli serve --port 2222
ssh localhost -p 2222
onyx-cli serve --host 0.0.0.0 --port 2222
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h`,
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h
onyx-cli serve --server-url https://my-onyx.example.com
onyx-cli serve --api-server-url http://api_server:8080 # bypass nginx
onyx-cli serve --config-file /etc/onyx-cli/config.json # global flag`,
RunE: func(cmd *cobra.Command, args []string) error {
serverCfg := config.Load()
serverCfg := loadConfig(cmd)
if cmd.Flags().Changed("server-url") {
serverCfg.ServerURL = serverURL
}
if cmd.Flags().Changed("api-server-url") {
serverCfg.InternalURL = apiServerURL
}
if serverCfg.ServerURL == "" {
return exitcodes.New(exitcodes.NotConfigured, "server URL is not configured\n Run: onyx-cli configure")
}
@@ -333,7 +340,7 @@ environment variable (the --host-key flag takes precedence).`,
var envErr string
if apiKey != "" {
if err := validateAPIKey(serverCfg.ServerURL, apiKey); err != nil {
if err := validateAPIKey(serverCfg, apiKey); err != nil {
envErr = fmt.Sprintf("ONYX_API_KEY from SSH environment is invalid: %s", err.Error())
apiKey = ""
}
@@ -341,11 +348,8 @@ environment variable (the --host-key flag takes precedence).`,
if apiKey != "" {
// Env key is valid — go straight to the TUI.
cfg := config.OnyxCliConfig{
ServerURL: serverCfg.ServerURL,
APIKey: apiKey,
DefaultAgentID: serverCfg.DefaultAgentID,
}
cfg := serverCfg
cfg.APIKey = apiKey
return tui.NewModel(cfg), []tea.ProgramOption{
tea.WithAltScreen(),
tea.WithMouseCellMotion(),
@@ -446,6 +450,10 @@ environment variable (the --host-key flag takes precedence).`,
defaultServeRateLimitCacheSize,
"Maximum number of IP limiter entries tracked in memory",
)
cmd.Flags().StringVar(&serverURL, "server-url", "",
"Onyx server URL (overrides config file and $"+config.EnvServerURL+")")
cmd.Flags().StringVar(&apiServerURL, "api-server-url", "",
"API server URL for direct access, bypassing nginx (overrides $"+config.EnvAPIServerURL+")")
return cmd
}

View File

@@ -4,10 +4,10 @@ import (
"context"
"errors"
"fmt"
"os"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
@@ -23,19 +23,21 @@ is valid. Also reports the server version and warns if it is below the
minimum required.`,
Example: ` onyx-cli validate-config`,
RunE: func(cmd *cobra.Command, args []string) error {
cfgPath := effectiveConfigPath(cmd)
// Check config file
if !config.ConfigExists() {
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", config.ConfigFilePath())
if _, err := os.Stat(cfgPath); err != nil {
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", cfgPath)
}
cfg := config.Load()
cfg := loadConfig(cmd)
// Check API key
if !cfg.IsConfigured() {
return exitcodes.New(exitcodes.NotConfigured, "API key is missing\n Run: onyx-cli configure")
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", cfgPath)
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server: %s\n", cfg.ServerURL)
// Test connection

View File

@@ -16,18 +16,30 @@ import (
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/models"
log "github.com/sirupsen/logrus"
)
// Client is the Onyx API client.
type Client struct {
baseURL string
serverURL string // root server URL (for reachability checks)
baseURL string // API base URL (includes /api when going through nginx)
apiKey string
httpClient *http.Client // default 30s timeout for quick requests
longHTTPClient *http.Client // 5min timeout for streaming/uploads
}
// NewClient creates a new API client from config.
//
// When InternalURL is set, requests go directly to the API server (no /api
// prefix needed — mirrors INTERNAL_URL in the web server). Otherwise,
// requests go through the nginx proxy at ServerURL which strips /api.
func NewClient(cfg config.OnyxCliConfig) *Client {
baseURL := apiBaseURL(cfg)
log.WithFields(log.Fields{
"server_url": cfg.ServerURL,
"internal_url": cfg.InternalURL,
"base_url": baseURL,
}).Debug("creating API client")
var transport *http.Transport
if t, ok := http.DefaultTransport.(*http.Transport); ok {
transport = t.Clone()
@@ -35,8 +47,9 @@ func NewClient(cfg config.OnyxCliConfig) *Client {
transport = &http.Transport{}
}
return &Client{
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
apiKey: cfg.APIKey,
serverURL: strings.TrimRight(cfg.ServerURL, "/"),
baseURL: baseURL,
apiKey: cfg.APIKey,
httpClient: &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
@@ -48,14 +61,27 @@ func NewClient(cfg config.OnyxCliConfig) *Client {
}
}
// apiBaseURL returns the base URL for API requests. When InternalURL is set,
// it points directly at the API server. Otherwise it goes through the nginx
// proxy at ServerURL/api.
func apiBaseURL(cfg config.OnyxCliConfig) string {
if cfg.InternalURL != "" {
return strings.TrimRight(cfg.InternalURL, "/")
}
return strings.TrimRight(cfg.ServerURL, "/") + "/api"
}
// UpdateConfig replaces the client's config.
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
c.serverURL = strings.TrimRight(cfg.ServerURL, "/")
c.baseURL = apiBaseURL(cfg)
c.apiKey = cfg.APIKey
}
func (c *Client) newRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body)
url := c.baseURL + path
log.WithFields(log.Fields{"method": method, "url": url}).Debug("API request")
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, err
}
@@ -87,12 +113,16 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody any, r
resp, err := c.httpClient.Do(req)
if err != nil {
log.WithError(err).WithField("url", req.URL.String()).Debug("API request failed")
return err
}
defer func() { _ = resp.Body.Close() }()
log.WithFields(log.Fields{"url": req.URL.String(), "status": resp.StatusCode}).Debug("API response")
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
log.WithFields(log.Fields{"status": resp.StatusCode, "body": string(respBody)}).Debug("API error response")
return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
}
@@ -105,16 +135,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody any, r
// TestConnection checks if the server is reachable and credentials are valid.
// Returns nil on success, or an error with a descriptive message on failure.
func (c *Client) TestConnection(ctx context.Context) error {
// Step 1: Basic reachability
req, err := c.newRequest(ctx, "GET", "/", nil)
// Step 1: Basic reachability (hit the server root, not the API prefix)
reachURL := c.serverURL
if reachURL == "" {
reachURL = c.baseURL
}
log.WithFields(log.Fields{
"reach_url": reachURL,
"base_url": c.baseURL,
}).Debug("testing connection — step 1: reachability")
req, err := http.NewRequestWithContext(ctx, "GET", reachURL+"/", nil)
if err != nil {
return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
return fmt.Errorf("cannot connect to %s: %w", reachURL, err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return fmt.Errorf("cannot connect to %s — is the server running?", c.baseURL)
log.WithError(err).Debug("reachability check failed")
return fmt.Errorf("cannot connect to %s — is the server running?", reachURL)
}
log.WithField("status", resp.StatusCode).Debug("reachability check response")
_ = resp.Body.Close()
serverHeader := strings.ToLower(resp.Header.Get("Server"))
@@ -127,7 +167,8 @@ func (c *Client) TestConnection(ctx context.Context) error {
}
// Step 2: Authenticated check
req2, err := c.newRequest(ctx, "GET", "/api/me", nil)
log.WithField("url", c.baseURL+"/me").Debug("testing connection — step 2: auth check")
req2, err := c.newRequest(ctx, "GET", "/me", nil)
if err != nil {
return fmt.Errorf("server reachable but API error: %w", err)
}
@@ -167,7 +208,7 @@ func (c *Client) TestConnection(ctx context.Context) error {
// ListAgents returns visible agents.
func (c *Client) ListAgents(ctx context.Context) ([]models.AgentSummary, error) {
var raw []models.AgentSummary
if err := c.doJSON(ctx, "GET", "/api/persona", nil, &raw); err != nil {
if err := c.doJSON(ctx, "GET", "/persona", nil, &raw); err != nil {
return nil, err
}
var result []models.AgentSummary
@@ -184,7 +225,7 @@ func (c *Client) ListChatSessions(ctx context.Context) ([]models.ChatSessionDeta
var resp struct {
Sessions []models.ChatSessionDetails `json:"sessions"`
}
if err := c.doJSON(ctx, "GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
if err := c.doJSON(ctx, "GET", "/chat/get-user-chat-sessions", nil, &resp); err != nil {
return nil, err
}
return resp.Sessions, nil
@@ -193,7 +234,7 @@ func (c *Client) ListChatSessions(ctx context.Context) ([]models.ChatSessionDeta
// GetChatSession returns full details for a session.
func (c *Client) GetChatSession(ctx context.Context, sessionID string) (*models.ChatSessionDetailResponse, error) {
var resp models.ChatSessionDetailResponse
if err := c.doJSON(ctx, "GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
if err := c.doJSON(ctx, "GET", "/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
@@ -210,7 +251,7 @@ func (c *Client) RenameChatSession(ctx context.Context, sessionID string, name *
var resp struct {
NewName string `json:"new_name"`
}
if err := c.doJSON(ctx, "PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
if err := c.doJSON(ctx, "PUT", "/chat/rename-chat-session", payload, &resp); err != nil {
return "", err
}
return resp.NewName, nil
@@ -236,7 +277,7 @@ func (c *Client) UploadFile(ctx context.Context, filePath string) (*models.FileD
}
_ = writer.Close()
req, err := c.newRequest(ctx, "POST", "/api/user/projects/file/upload", &buf)
req, err := c.newRequest(ctx, "POST", "/user/projects/file/upload", &buf)
if err != nil {
return nil, err
}
@@ -275,7 +316,7 @@ func (c *Client) GetBackendVersion(ctx context.Context) (string, error) {
var resp struct {
BackendVersion string `json:"backend_version"`
}
if err := c.doJSON(ctx, "GET", "/api/version", nil, &resp); err != nil {
if err := c.doJSON(ctx, "GET", "/version", nil, &resp); err != nil {
return "", err
}
return resp.BackendVersion, nil
@@ -283,7 +324,7 @@ func (c *Client) GetBackendVersion(ctx context.Context) (string, error) {
// StopChatSession sends a stop signal for a streaming session (best-effort).
func (c *Client) StopChatSession(ctx context.Context, sessionID string) {
req, err := c.newRequest(ctx, "POST", "/api/chat/stop-chat-session/"+sessionID, nil)
req, err := c.newRequest(ctx, "POST", "/chat/stop-chat-session/"+sessionID, nil)
if err != nil {
return
}

View File

@@ -64,7 +64,7 @@ func (c *Client) SendMessageStream(
return
}
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", nil)
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/chat/send-chat-message", nil)
if err != nil {
ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
return

View File

@@ -10,6 +10,7 @@ import (
const (
EnvServerURL = "ONYX_SERVER_URL"
EnvAPIServerURL = "ONYX_API_SERVER_URL"
EnvAPIKey = "ONYX_API_KEY"
EnvAgentID = "ONYX_PERSONA_ID"
EnvSSHHostKey = "ONYX_SSH_HOST_KEY"
@@ -27,6 +28,7 @@ type Features struct {
// OnyxCliConfig holds the CLI configuration.
type OnyxCliConfig struct {
ServerURL string `json:"server_url"`
InternalURL string `json:"internal_url,omitempty"`
APIKey string `json:"api_key"`
DefaultAgentID int `json:"default_persona_id"`
Features Features `json:"features,omitempty"`
@@ -78,30 +80,47 @@ func ConfigExists() bool {
return err == nil
}
// LoadFromDisk reads config from the file only, without applying environment
// variable overrides. Use this when you need the persisted config values
// (e.g., to preserve them during a save operation).
func LoadFromDisk() OnyxCliConfig {
// LoadFromDisk reads config from the given file path without applying
// environment variable overrides. Use this when you need the persisted
// config values (e.g., to preserve them during a save operation).
// If no path is provided, the default config file path is used.
func LoadFromDisk(path ...string) OnyxCliConfig {
p := ConfigFilePath()
if len(path) > 0 && path[0] != "" {
p = path[0]
}
cfg := DefaultConfig()
data, err := os.ReadFile(ConfigFilePath())
data, err := os.ReadFile(p)
if err == nil {
if jsonErr := json.Unmarshal(data, &cfg); jsonErr != nil {
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", ConfigFilePath(), jsonErr)
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", p, jsonErr)
}
}
return cfg
}
// Load reads config from file and applies environment variable overrides.
func Load() OnyxCliConfig {
cfg := LoadFromDisk()
// Load reads config from the given file path and applies environment variable
// overrides. If no path is provided, the default config file path is used.
func Load(path ...string) OnyxCliConfig {
cfg := LoadFromDisk(path...)
applyEnvOverrides(&cfg)
return cfg
}
// Environment overrides
func applyEnvOverrides(cfg *OnyxCliConfig) {
if v := os.Getenv(EnvServerURL); v != "" {
cfg.ServerURL = v
}
// ONYX_API_SERVER_URL takes precedence; fall back to INTERNAL_URL
// (the env var used by the web server) for compatibility.
if v := os.Getenv(EnvAPIServerURL); v != "" {
cfg.InternalURL = v
} else if v := os.Getenv("INTERNAL_URL"); v != "" {
cfg.InternalURL = v
}
if v := os.Getenv(EnvAPIKey); v != "" {
cfg.APIKey = v
}
@@ -117,8 +136,6 @@ func Load() OnyxCliConfig {
fmt.Fprintf(os.Stderr, "warning: invalid value %q for %s (expected true/false), ignoring\n", v, EnvStreamMarkdown)
}
}
return cfg
}
// Save writes the config to disk, creating parent directories if needed.

View File

@@ -19,6 +19,6 @@ dependencies:
version: 5.4.0
- name: code-interpreter
repository: https://onyx-dot-app.github.io/python-sandbox/
version: 0.3.2
digest: sha256:74908ea45ace2b4be913ff762772e6d87e40bab64e92c6662aa51730eaeb9d87
generated: "2026-04-06T15:34:02.597166-07:00"
version: 0.3.1
digest: sha256:4965b6ea3674c37163832a2192cd3bc8004f2228729fca170af0b9f457e8f987
generated: "2026-03-02T15:29:39.632344-08:00"

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.4.40
version: 0.4.39
appVersion: latest
annotations:
category: Productivity
@@ -45,6 +45,6 @@ dependencies:
repository: https://charts.min.io/
condition: minio.enabled
- name: code-interpreter
version: 0.3.2
version: 0.3.1
repository: https://onyx-dot-app.github.io/python-sandbox/
condition: codeInterpreter.enabled

View File

@@ -67,9 +67,6 @@ spec:
- "/bin/sh"
- "-c"
- |
{{- if .Values.api.runUpdateCaCertificates }}
update-ca-certificates &&
{{- end }}
alembic upgrade head &&
echo "Starting Onyx Api Server" &&
uvicorn onyx.main:app --host {{ .Values.global.host }} --port {{ .Values.api.containerPorts.server }}

View File

@@ -504,18 +504,6 @@ api:
tolerations: []
affinity: {}
# Run update-ca-certificates before starting the server.
# Useful when mounting custom CA certificates via volumes/volumeMounts.
# NOTE: Requires the container to run as root (runAsUser: 0).
# CA certificate files must be mounted under /usr/local/share/ca-certificates/
# with a .crt extension (e.g. /usr/local/share/ca-certificates/my-ca.crt).
# NOTE: Python HTTP clients (requests, httpx) use certifi's bundle by default
# and will not pick up the system CA store automatically. Set the following
# environment variables via configMap values (loaded through envFrom) to make them use the updated system bundle:
# REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
# SSL_CERT_FILE: /etc/ssl/certs/ca-certificates.crt
runUpdateCaCertificates: false
######################################################################
#

View File

@@ -30,10 +30,7 @@ target "backend" {
context = "backend"
dockerfile = "Dockerfile"
cache-from = [
"type=registry,ref=${BACKEND_REPOSITORY}:latest",
"type=registry,ref=${BACKEND_REPOSITORY}:edge",
]
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
@@ -43,10 +40,7 @@ target "web" {
context = "web"
dockerfile = "Dockerfile"
cache-from = [
"type=registry,ref=${WEB_SERVER_REPOSITORY}:latest",
"type=registry,ref=${WEB_SERVER_REPOSITORY}:edge",
]
cache-from = ["type=registry,ref=${WEB_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${WEB_SERVER_REPOSITORY}:${TAG}"]
@@ -57,10 +51,7 @@ target "model-server" {
dockerfile = "Dockerfile.model_server"
cache-from = [
"type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest",
"type=registry,ref=${MODEL_SERVER_REPOSITORY}:edge",
]
cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
@@ -82,10 +73,7 @@ target "cli" {
context = "cli"
dockerfile = "Dockerfile"
cache-from = [
"type=registry,ref=${CLI_REPOSITORY}:latest",
"type=registry,ref=${CLI_REPOSITORY}:edge",
]
cache-from = ["type=registry,ref=${CLI_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${CLI_REPOSITORY}:${TAG}"]

View File

@@ -6,11 +6,11 @@ All Prometheus metrics live in the `backend/onyx/server/metrics/` package. Follo
### 1. Choose the right file (or create a new one)
| File | Purpose |
| ------------------------------------- | -------------------------------------------- |
| `metrics/slow_requests.py` | Slow request counter + callback |
| `metrics/postgres_connection_pool.py` | SQLAlchemy connection pool metrics |
| `metrics/prometheus_setup.py` | FastAPI instrumentator config (orchestrator) |
| File | Purpose |
|------|---------|
| `metrics/slow_requests.py` | Slow request counter + callback |
| `metrics/postgres_connection_pool.py` | SQLAlchemy connection pool metrics |
| `metrics/prometheus_setup.py` | FastAPI instrumentator config (orchestrator) |
If your metric is a standalone concern (e.g. cache hit rates, queue depths), create a new file under `metrics/` and keep one metric concept per file.
@@ -30,7 +30,6 @@ _my_counter = Counter(
```
**Naming conventions:**
- Prefix all metric names with `onyx_`
- Counters: `_total` suffix (e.g. `onyx_api_slow_requests_total`)
- Histograms: `_seconds` or `_bytes` suffix for durations/sizes
@@ -108,26 +107,26 @@ These metrics are exposed at `GET /metrics` on the API server.
### Built-in (via `prometheus-fastapi-instrumentator`)
| Metric | Type | Labels | Description |
| ------------------------------------- | --------- | ----------------------------- | ------------------------------------------------- |
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (custom buckets for P95/P99) |
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (custom buckets for P95/P99) |
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |
### Custom (via `onyx.server.metrics`)
| Metric | Type | Labels | Description |
| ------------------------------ | ------- | ----------------------------- | ---------------------------------------------------------------- |
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_api_slow_requests_total` | Counter | `method`, `handler`, `status` | Requests exceeding `SLOW_REQUEST_THRESHOLD_SECONDS` (default 1s) |
### Configuration
| Env Var | Default | Description |
| -------------------------------- | ------- | -------------------------------------------- |
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |
| Env Var | Default | Description |
|---------|---------|-------------|
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |
### Instrumentator Settings
@@ -142,188 +141,44 @@ These metrics provide visibility into SQLAlchemy connection pool state across al
### Pool State (via custom Prometheus collector — snapshot on each scrape)
| Metric | Type | Labels | Description |
| -------------------------- | ----- | -------- | ----------------------------------------------- |
| `onyx_db_pool_checked_out` | Gauge | `engine` | Currently checked-out connections |
| `onyx_db_pool_checked_in` | Gauge | `engine` | Idle connections available in the pool |
| `onyx_db_pool_overflow` | Gauge | `engine` | Current overflow connections beyond `pool_size` |
| `onyx_db_pool_size` | Gauge | `engine` | Configured pool size (constant) |
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_db_pool_checked_out` | Gauge | `engine` | Currently checked-out connections |
| `onyx_db_pool_checked_in` | Gauge | `engine` | Idle connections available in the pool |
| `onyx_db_pool_overflow` | Gauge | `engine` | Current overflow connections beyond `pool_size` |
| `onyx_db_pool_size` | Gauge | `engine` | Configured pool size (constant) |
### Pool Lifecycle (via SQLAlchemy pool event listeners)
| Metric | Type | Labels | Description |
| ---------------------------------------- | ------- | -------- | ---------------------------------------- |
| `onyx_db_pool_checkout_total` | Counter | `engine` | Total connection checkouts from the pool |
| `onyx_db_pool_checkin_total` | Counter | `engine` | Total connection checkins to the pool |
| `onyx_db_pool_connections_created_total` | Counter | `engine` | Total new database connections created |
| `onyx_db_pool_invalidations_total` | Counter | `engine` | Total connection invalidations |
| `onyx_db_pool_checkout_timeout_total` | Counter | `engine` | Total connection checkout timeouts |
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_db_pool_checkout_total` | Counter | `engine` | Total connection checkouts from the pool |
| `onyx_db_pool_checkin_total` | Counter | `engine` | Total connection checkins to the pool |
| `onyx_db_pool_connections_created_total` | Counter | `engine` | Total new database connections created |
| `onyx_db_pool_invalidations_total` | Counter | `engine` | Total connection invalidations |
| `onyx_db_pool_checkout_timeout_total` | Counter | `engine` | Total connection checkout timeouts |
### Per-Endpoint Attribution (via pool events + endpoint context middleware)
| Metric | Type | Labels | Description |
| -------------------------------------- | --------- | ------------------- | ----------------------------------------------- |
| `onyx_db_connections_held_by_endpoint` | Gauge | `handler`, `engine` | DB connections currently held, by endpoint |
| `onyx_db_connection_hold_seconds` | Histogram | `handler`, `engine` | Duration a DB connection is held by an endpoint |
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_db_connections_held_by_endpoint` | Gauge | `handler`, `engine` | DB connections currently held, by endpoint |
| `onyx_db_connection_hold_seconds` | Histogram | `handler`, `engine` | Duration a DB connection is held by an endpoint |
Engine label values: `sync` (main read-write), `async` (async sessions), `readonly` (read-only user).
Connections from background tasks (Celery) or boot-time warmup appear as `handler="unknown"`.
## Celery Worker Metrics
Celery workers expose metrics via a standalone Prometheus HTTP server (separate from the API server's `/metrics` endpoint). Each worker type runs its own server on a dedicated port.
### Metrics Server (`onyx.server.metrics.metrics_server`)
| Env Var | Default | Description |
| ---------------------------- | ------------------- | ----------------------------------------------------- |
| `PROMETHEUS_METRICS_PORT` | _(per worker type)_ | Override the default port for this worker |
| `PROMETHEUS_METRICS_ENABLED` | `true` | Set to `false` to disable the metrics server entirely |
Default ports:
| Worker | Port |
| --------------- | ---- |
| `docfetching` | 9092 |
| `docprocessing` | 9093 |
| `monitoring` | 9096 |
Workers without a default port and no `PROMETHEUS_METRICS_PORT` env var will skip starting the server.
### Generic Task Lifecycle Metrics (`onyx.server.metrics.celery_task_metrics`)
Push-based metrics that fire on Celery signals for all tasks on the worker.
| Metric | Type | Labels | Description |
| ----------------------------------- | --------- | ------------------------------- | ----------------------------------------------------------------------------- |
| `onyx_celery_task_started_total` | Counter | `task_name`, `queue` | Total tasks started |
| `onyx_celery_task_completed_total` | Counter | `task_name`, `queue`, `outcome` | Total tasks completed (`outcome`: `success` or `failure`) |
| `onyx_celery_task_duration_seconds` | Histogram | `task_name`, `queue` | Task execution duration. Buckets: 1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600 |
| `onyx_celery_tasks_active` | Gauge | `task_name`, `queue` | Currently executing tasks |
| `onyx_celery_task_retried_total` | Counter | `task_name`, `queue` | Total task retries |
| `onyx_celery_task_revoked_total` | Counter | `task_name` | Total tasks revoked (cancelled) |
| `onyx_celery_task_rejected_total` | Counter | `task_name` | Total tasks rejected by worker |
Stale start-time entries (tasks killed via SIGTERM/OOM where `task_postrun` never fires) are evicted after 1 hour.
### Per-Connector Indexing Metrics (`onyx.server.metrics.indexing_task_metrics`)
Enriches docfetching and docprocessing tasks with connector-level labels. Silently no-ops for all other tasks.
| Metric | Type | Labels | Description |
| ------------------------------------- | --------- | ----------------------------------------------------------- | ---------------------------------------- |
| `onyx_indexing_task_started_total` | Counter | `task_name`, `source`, `tenant_id`, `cc_pair_id` | Indexing tasks started per connector |
| `onyx_indexing_task_completed_total` | Counter | `task_name`, `source`, `tenant_id`, `cc_pair_id`, `outcome` | Indexing tasks completed per connector |
| `onyx_indexing_task_duration_seconds` | Histogram | `task_name`, `source`, `tenant_id` | Indexing task duration by connector type |
`connector_name` is intentionally excluded from these push-based counters to avoid unbounded cardinality (it's a free-form user string). The pull-based collectors on the monitoring worker include it since they have bounded cardinality (one series per connector).
### Pull-Based Collectors (`onyx.server.metrics.indexing_pipeline`)
Registered only in the **Monitoring** worker. Collectors query Redis/Postgres at scrape time with a 30-second TTL cache.
| Metric | Type | Labels | Description |
| ------------------------------------ | ----- | ------- | ----------------------------------- |
| `onyx_queue_depth` | Gauge | `queue` | Celery queue length |
| `onyx_queue_unacked` | Gauge | `queue` | Unacknowledged messages per queue |
| `onyx_queue_oldest_task_age_seconds` | Gauge | `queue` | Age of the oldest task in the queue |
Plus additional connector health, index attempt, and worker heartbeat metrics — see `indexing_pipeline.py` for the full list.
### Adding Metrics to a Worker
Currently only the docfetching and docprocessing workers have push-based task metrics wired up. To add metrics to another worker (e.g. heavy, light, primary):
**1. Import and call the generic handlers from the worker's signal handlers:**
```python
from onyx.server.metrics.celery_task_metrics import (
on_celery_task_prerun,
on_celery_task_postrun,
on_celery_task_retry,
on_celery_task_revoked,
on_celery_task_rejected,
)
@signals.task_prerun.connect
def on_task_prerun(sender, task_id, task, args, kwargs, **kwds):
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
```
Do the same for `task_postrun`, `task_retry`, `task_revoked`, and `task_rejected` — see `apps/docfetching.py` for the complete example.
**2. Start the metrics server on `worker_ready`:**
```python
from onyx.server.metrics.metrics_server import start_metrics_server
@worker_ready.connect
def on_worker_ready(sender, **kwargs):
start_metrics_server("your_worker_type")
app_base.on_worker_ready(sender, **kwargs)
```
Add a default port for your worker type in `metrics_server.py`'s `_DEFAULT_PORTS` dict, or set `PROMETHEUS_METRICS_PORT` in the environment.
**3. (Optional) Add domain-specific enrichment:**
If your tasks need richer labels beyond `task_name`/`queue`, create a new module in `server/metrics/` following `indexing_task_metrics.py`:
- Define Counters/Histograms with your domain labels
- Write `on_<domain>_task_prerun` / `on_<domain>_task_postrun` handlers that filter by task name and no-op for others
- Call them from the worker's signal handlers alongside the generic ones
**Cardinality warning:** Never use user-defined free-form strings as metric labels — they create unbounded cardinality. Use IDs or enum values. If you need free-form labels, use pull-based collectors (monitoring worker) where cardinality is naturally bounded.
### Current Worker Integration Status
| Worker | Generic Task Metrics | Domain Metrics | Metrics Server |
| -------------------- | -------------------- | -------------- | ------------------------------------ |
| Docfetching | ✓ | ✓ (indexing) | ✓ (port 9092) |
| Docprocessing | ✓ | ✓ (indexing) | ✓ (port 9093) |
| Monitoring | — | — | ✓ (port 9096, pull-based collectors) |
| Primary | — | — | — |
| Light | — | — | — |
| Heavy | — | — | — |
| User File Processing | — | — | — |
| KG Processing | — | — | — |
### Example PromQL Queries (Celery)
```promql
# Task completion rate by worker queue
sum by (queue) (rate(onyx_celery_task_completed_total[5m]))
# P95 task duration for pruning tasks
histogram_quantile(0.95,
sum by (le) (rate(onyx_celery_task_duration_seconds_bucket{task_name=~".*pruning.*"}[5m])))
# Task failure rate
sum by (task_name) (rate(onyx_celery_task_completed_total{outcome="failure"}[5m]))
/ sum by (task_name) (rate(onyx_celery_task_completed_total[5m]))
# Active tasks per queue
sum by (queue) (onyx_celery_tasks_active)
# Indexing throughput by source type
sum by (source) (rate(onyx_indexing_task_completed_total{outcome="success"}[5m]))
# Queue depth — are tasks backing up?
onyx_queue_depth > 100
```
## OpenSearch Search Metrics
These metrics track OpenSearch search latency and throughput. Collected via `onyx.server.metrics.opensearch_search`.
| Metric | Type | Labels | Description |
| ------------------------------------------------ | --------- | ------------- | --------------------------------------------------------------------------- |
| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_opensearch_search_client_duration_seconds` | Histogram | `search_type` | Client-side end-to-end latency (network + serialization + server execution) |
| `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field |
| `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch |
| `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches |
| `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field |
| `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch |
| `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches |
Search type label values: See `OpenSearchSearchType`.

View File

@@ -70,10 +70,6 @@ backend = [
"lazy_imports==1.0.1",
"lxml==5.3.0",
"Mako==1.2.4",
# NOTE: Do not update without understanding the patching behavior in
# get_markitdown_converter in
# backend/onyx/file_processing/extract_file_text.py and what impacts
# updating might have on this behavior.
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
"mcp[cli]==1.26.0",
"msal==1.34.0",

View File

@@ -1,22 +1,18 @@
import { Card } from "@opal/components/cards/card/components";
import { Content, SizePreset } from "@opal/layouts";
import { Content } from "@opal/layouts";
import { SvgEmpty } from "@opal/icons";
import type {
IconFunctionComponent,
PaddingVariants,
RichStr,
} from "@opal/types";
import type { IconFunctionComponent, PaddingVariants } from "@opal/types";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type EmptyMessageCardBaseProps = {
type EmptyMessageCardProps = {
/** Icon displayed alongside the title. */
icon?: IconFunctionComponent;
/** Primary message text. */
title: string | RichStr;
title: string;
/** Padding preset for the card. @default "md" */
padding?: PaddingVariants;
@@ -25,30 +21,16 @@ type EmptyMessageCardBaseProps = {
ref?: React.Ref<HTMLDivElement>;
};
type EmptyMessageCardProps =
| (EmptyMessageCardBaseProps & {
/** @default "secondary" */
sizePreset?: "secondary";
})
| (EmptyMessageCardBaseProps & {
sizePreset: "main-ui";
/** Description text. Only supported when `sizePreset` is `"main-ui"`. */
description?: string | RichStr;
});
// ---------------------------------------------------------------------------
// EmptyMessageCard
// ---------------------------------------------------------------------------
function EmptyMessageCard(props: EmptyMessageCardProps) {
const {
sizePreset = "secondary",
icon = SvgEmpty,
title,
padding = "md",
ref,
} = props;
function EmptyMessageCard({
icon = SvgEmpty,
title,
padding = "md",
ref,
}: EmptyMessageCardProps) {
return (
<Card
ref={ref}
@@ -57,23 +39,13 @@ function EmptyMessageCard(props: EmptyMessageCardProps) {
padding={padding}
rounding="md"
>
{sizePreset === "secondary" ? (
<Content
icon={icon}
title={title}
sizePreset="secondary"
variant="body"
prominence="muted"
/>
) : (
<Content
icon={icon}
title={title}
description={"description" in props ? props.description : undefined}
sizePreset={sizePreset}
variant="section"
/>
)}
<Content
icon={icon}
title={title}
sizePreset="secondary"
variant="body"
prominence="muted"
/>
</Card>
);
}

View File

@@ -1,16 +1,41 @@
"use client";
import "@opal/core/animations/styles.css";
import React from "react";
import React, { createContext, useContext, useState, useCallback } from "react";
import { cn } from "@opal/utils";
import type { WithoutStyles, ExtremaSizeVariants } from "@opal/types";
import { widthVariants } from "@opal/shared";
// ---------------------------------------------------------------------------
// Types
// Context-per-group registry
// ---------------------------------------------------------------------------
type HoverableInteraction = "rest" | "hover";
/**
* Lazily-created map of group names to React contexts.
*
* Each group gets its own `React.Context<boolean | null>` so that a
* `Hoverable.Item` only re-renders when its *own* group's hover state
* changes — not when any unrelated group changes.
*
* The default value is `null` (no provider found), which lets
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
* not hovered" and throw when `group` was explicitly specified.
*/
const contextMap = new Map<string, React.Context<boolean | null>>();
function getOrCreateContext(group: string): React.Context<boolean | null> {
let ctx = contextMap.get(group);
if (!ctx) {
ctx = createContext<boolean | null>(null);
ctx.displayName = `HoverableContext(${group})`;
contextMap.set(group, ctx);
}
return ctx;
}
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface HoverableRootProps
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
@@ -18,17 +43,6 @@ interface HoverableRootProps
group: string;
/** Width preset. @default "auto" */
widthVariant?: ExtremaSizeVariants;
/**
* JS-controllable interaction state override.
*
* - `"rest"` (default): items are shown/hidden by CSS `:hover`.
* - `"hover"`: forces items visible regardless of hover state. Useful when
* a hoverable action opens a modal — set `interaction="hover"` while the
* modal is open so the user can see which element they're interacting with.
*
* @default "rest"
*/
interaction?: HoverableInteraction;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
}
@@ -51,10 +65,12 @@ interface HoverableItemProps
/**
* Hover-tracking container for a named group.
*
* Uses a `data-hover-group` attribute and CSS `:hover` to control
* descendant `Hoverable.Item` visibility. No React state or context
* the browser natively removes `:hover` when modals/portals steal
* pointer events, preventing stale hover state.
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
* provides the hover state via a per-group React context.
*
* Nesting works because each `Hoverable.Root` creates a **new** context
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
* reads from the inner provider, not the outer `group="a"` provider.
*
* @example
* ```tsx
@@ -71,20 +87,70 @@ function HoverableRoot({
group,
children,
widthVariant = "full",
interaction = "rest",
ref,
onMouseEnter: consumerMouseEnter,
onMouseLeave: consumerMouseLeave,
onFocusCapture: consumerFocusCapture,
onBlurCapture: consumerBlurCapture,
...props
}: HoverableRootProps) {
const [hovered, setHovered] = useState(false);
const [focused, setFocused] = useState(false);
const onMouseEnter = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
setHovered(true);
consumerMouseEnter?.(e);
},
[consumerMouseEnter]
);
const onMouseLeave = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
setHovered(false);
consumerMouseLeave?.(e);
},
[consumerMouseLeave]
);
const onFocusCapture = useCallback(
(e: React.FocusEvent<HTMLDivElement>) => {
setFocused(true);
consumerFocusCapture?.(e);
},
[consumerFocusCapture]
);
const onBlurCapture = useCallback(
(e: React.FocusEvent<HTMLDivElement>) => {
if (
!(e.relatedTarget instanceof Node) ||
!e.currentTarget.contains(e.relatedTarget)
) {
setFocused(false);
}
consumerBlurCapture?.(e);
},
[consumerBlurCapture]
);
const active = hovered || focused;
const GroupContext = getOrCreateContext(group);
return (
<div
{...props}
ref={ref}
className={cn(widthVariants[widthVariant])}
data-hover-group={group}
data-interaction={interaction !== "rest" ? interaction : undefined}
>
{children}
</div>
<GroupContext.Provider value={active}>
<div
{...props}
ref={ref}
className={cn(widthVariants[widthVariant])}
onMouseEnter={onMouseEnter}
onMouseLeave={onMouseLeave}
onFocusCapture={onFocusCapture}
onBlurCapture={onBlurCapture}
>
{children}
</div>
</GroupContext.Provider>
);
}
@@ -96,10 +162,13 @@ function HoverableRoot({
* An element whose visibility is controlled by hover state.
*
* **Local mode** (`group` omitted): the item handles hover on its own
* element via CSS `:hover`.
* element via CSS `:hover`. This is the core abstraction.
*
* **Group mode** (`group` provided): visibility is driven by CSS `:hover`
* on the nearest `Hoverable.Root` ancestor via `[data-hover-group]:hover`.
* **Group mode** (`group` provided): visibility is driven by a matching
* `Hoverable.Root` ancestor's hover state via React context. If no
* matching Root is found, an error is thrown.
*
* Uses data-attributes for variant styling (see `styles.css`).
*
* @example
* ```tsx
@@ -115,6 +184,8 @@ function HoverableRoot({
* </Hoverable.Item>
* </Hoverable.Root>
* ```
*
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
*/
function HoverableItem({
group,
@@ -123,6 +194,17 @@ function HoverableItem({
ref,
...props
}: HoverableItemProps) {
const contextValue = useContext(
group ? getOrCreateContext(group) : NOOP_CONTEXT
);
if (group && contextValue === null) {
throw new Error(
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
);
}
const isLocal = group === undefined;
return (
@@ -131,6 +213,9 @@ function HoverableItem({
ref={ref}
className={cn("hoverable-item")}
data-hoverable-variant={variant}
data-hoverable-active={
isLocal ? undefined : contextValue ? "true" : undefined
}
data-hoverable-local={isLocal ? "true" : undefined}
>
{children}
@@ -138,6 +223,9 @@ function HoverableItem({
);
}
/**
 * Stable context consumed when `Hoverable.Item` has no `group` (local mode).
 * Always provides `null`. Sharing one module-level instance lets
 * `HoverableItem` call `useContext` unconditionally, keeping React hook
 * order stable across renders.
 */
const NOOP_CONTEXT = createContext<boolean | null>(null);
// ---------------------------------------------------------------------------
// Compound export
// ---------------------------------------------------------------------------
@@ -145,16 +233,18 @@ function HoverableItem({
/**
* Hoverable compound component for hover-to-reveal patterns.
*
* Entirely CSS-driven — no React state or context. The browser's native
* `:hover` pseudo-class handles all state, which means hover is
* automatically cleared when modals/portals steal pointer events.
* Provides two sub-components:
*
* - `Hoverable.Root` — Container with `data-hover-group`. CSS `:hover`
* on this element reveals descendant `Hoverable.Item` elements.
* - `Hoverable.Root` — A container that tracks hover state for a named group
* and provides it via React context.
*
* - `Hoverable.Item` — Hidden by default. In group mode, revealed when
* the ancestor Root is hovered. In local mode (no `group`), revealed
* when the item itself is hovered.
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
* applies local CSS `:hover` for the variant effect. When `group` is
* specified, it reads hover state from the nearest matching
* `Hoverable.Root` — and throws if no matching Root is found.
*
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
* so each group's items only respond to their own root's hover.
*
* @example
* ```tsx
@@ -186,5 +276,4 @@ export {
type HoverableRootProps,
type HoverableItemProps,
type HoverableItemVariant,
type HoverableInteraction,
};

View File

@@ -7,20 +7,8 @@
opacity: 0;
}
/* Group mode — Root :hover controls descendant item visibility via CSS.
Exclude local-mode items so they aren't revealed by an ancestor root. */
[data-hover-group]:hover
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:not(
[data-hoverable-local]
) {
opacity: 1;
}
/* Interaction override — force items visible via JS */
[data-hover-group][data-interaction="hover"]
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:not(
[data-hoverable-local]
) {
/* Group mode — Root controls visibility via React context */
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
opacity: 1;
}
@@ -29,16 +17,7 @@
opacity: 1;
}
/* Group focus — any focusable descendant of the Root receives keyboard focus,
revealing all group items (same behavior as hover). */
[data-hover-group]:focus-within
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:not(
[data-hoverable-local]
) {
opacity: 1;
}
/* Local focus — item (or a focusable descendant) receives keyboard focus */
/* Focus — item (or a focusable descendant) receives keyboard focus */
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:has(:focus-visible) {
opacity: 1;
}

6
web/package-lock.json generated
View File

@@ -18122,9 +18122,9 @@
}
},
"node_modules/vite": {
"version": "6.4.2",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.2.tgz",
"integrity": "sha512-2N/55r4JDJ4gdrCvGgINMy+HH3iRpNIz8K6SFwVsA+JbQScLiC+clmAxBgwiSPgcG9U15QmvqCGWzMbqda5zGQ==",
"version": "6.4.1",
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,

View File

@@ -1 +1 @@
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";

View File

@@ -32,10 +32,8 @@ import {
OpenRouterFetchParams,
LiteLLMProxyFetchParams,
BifrostFetchParams,
OpenAICompatibleFetchParams,
OpenAICompatibleModelResponse,
} from "@/interfaces/llm";
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -46,7 +44,6 @@ export const AGGREGATOR_PROVIDERS = new Set([
"lm_studio",
"litellm_proxy",
"bifrost",
"openai_compatible",
"vertex_ai",
]);
@@ -85,7 +82,6 @@ export const getProviderIcon = (
openrouter: SvgOpenrouter,
litellm_proxy: LiteLLMIcon,
bifrost: SvgBifrost,
openai_compatible: SvgPlug,
vertex_ai: GeminiIcon,
};
@@ -415,64 +411,6 @@ export const fetchBifrostModels = async (
}
};
/**
 * Fetches the model list from a generic OpenAI-compatible server via the
 * admin proxy endpoint. Parameter names are snake_case to mirror the API
 * payload structure.
 *
 * Never throws — all failures are surfaced through the `error` field of the
 * returned object, with `models` empty.
 */
export const fetchOpenAICompatibleModels = async (
  params: OpenAICompatibleFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
  const { api_base, api_key, provider_name, signal } = params;

  // API base is the only mandatory field; bail out before hitting the network.
  if (!api_base) {
    return { models: [], error: "API Base is required" };
  }

  try {
    const response = await fetch(
      "/api/admin/llm/openai-compatible/available-models",
      {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          api_base: api_base,
          api_key: api_key,
          provider_name: provider_name,
        }),
        signal,
      }
    );

    if (!response.ok) {
      // Prefer the server-provided detail/message; fall back to a generic one.
      let errorMessage = "Failed to fetch models";
      try {
        const errorData = await response.json();
        errorMessage = errorData.detail || errorData.message || errorMessage;
      } catch {
        // Error body wasn't JSON — keep the generic message.
      }
      return { models: [], error: errorMessage };
    }

    const data: OpenAICompatibleModelResponse[] = await response.json();

    // Map the API response into the UI's model-configuration shape.
    return {
      models: data.map((modelData) => ({
        name: modelData.name,
        display_name: modelData.display_name,
        is_visible: true,
        max_input_tokens: modelData.max_input_tokens,
        supports_image_input: modelData.supports_image_input,
        supports_reasoning: modelData.supports_reasoning,
      })),
    };
  } catch (error) {
    return {
      models: [],
      error: error instanceof Error ? error.message : "Unknown error",
    };
  }
};
/**
* Fetches LiteLLM Proxy models directly without any form state dependencies.
* Uses snake_case params to match API structure.
@@ -593,13 +531,6 @@ export const fetchModels = async (
provider_name: formValues.name,
signal,
});
case LLMProviderName.OPENAI_COMPATIBLE:
return fetchOpenAICompatibleModels({
api_base: formValues.api_base,
api_key: formValues.api_key,
provider_name: formValues.name,
signal,
});
default:
return { models: [], error: `Unknown provider: ${providerName}` };
}
@@ -614,7 +545,6 @@ export function canProviderFetchModels(providerName?: string) {
case LLMProviderName.OPENROUTER:
case LLMProviderName.LITELLM_PROXY:
case LLMProviderName.BIFROST:
case LLMProviderName.OPENAI_COMPATIBLE:
return true;
default:
return false;

View File

@@ -64,50 +64,50 @@ export default function CreateRateLimitModal({
title="Create a Token Rate Limit"
onClose={() => setIsOpen(false)}
/>
<Formik
initialValues={{
enabled: true,
period_hours: "",
token_budget: "",
target_scope: forSpecificScope || Scope.GLOBAL,
user_group_id: forSpecificUserGroup,
}}
validationSchema={Yup.object().shape({
period_hours: Yup.number()
.required("Time Window is a required field")
.min(1, "Time Window must be at least 1 hour"),
token_budget: Yup.number()
.required("Token Budget is a required field")
.min(1, "Token Budget must be at least 1"),
target_scope: Yup.string().required(
"Target Scope is a required field"
),
user_group_id: Yup.string().test(
"user_group_id",
"User Group is a required field",
(value, context) => {
return (
context.parent.target_scope !== "user_group" ||
(context.parent.target_scope === "user_group" &&
value !== undefined)
);
}
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
onSubmit(
values.target_scope,
Number(values.period_hours),
Number(values.token_budget),
Number(values.user_group_id)
);
return formikHelpers.setSubmitting(false);
}}
>
{({ isSubmitting, values, setFieldValue }) => (
<Form className="flex flex-col h-full min-h-0 overflow-visible">
<Modal.Body>
<Modal.Body>
<Formik
initialValues={{
enabled: true,
period_hours: "",
token_budget: "",
target_scope: forSpecificScope || Scope.GLOBAL,
user_group_id: forSpecificUserGroup,
}}
validationSchema={Yup.object().shape({
period_hours: Yup.number()
.required("Time Window is a required field")
.min(1, "Time Window must be at least 1 hour"),
token_budget: Yup.number()
.required("Token Budget is a required field")
.min(1, "Token Budget must be at least 1"),
target_scope: Yup.string().required(
"Target Scope is a required field"
),
user_group_id: Yup.string().test(
"user_group_id",
"User Group is a required field",
(value, context) => {
return (
context.parent.target_scope !== "user_group" ||
(context.parent.target_scope === "user_group" &&
value !== undefined)
);
}
),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
onSubmit(
values.target_scope,
Number(values.period_hours),
Number(values.token_budget),
Number(values.user_group_id)
);
return formikHelpers.setSubmitting(false);
}}
>
{({ isSubmitting, values, setFieldValue }) => (
<Form className="overflow-visible px-2">
{!forSpecificScope && (
<SelectorFormField
name="target_scope"
@@ -147,15 +147,13 @@ export default function CreateRateLimitModal({
type="number"
placeholder=""
/>
</Modal.Body>
<Modal.Footer>
<Button disabled={isSubmitting} type="submit">
Create
</Button>
</Modal.Footer>
</Form>
)}
</Formik>
</Form>
)}
</Formik>
</Modal.Body>
</Modal.Content>
</Modal>
);

View File

@@ -1,126 +0,0 @@
"use client";
import { useCallback } from "react";
import { Button } from "@opal/components";
import { Text } from "@opal/components";
import { ContentAction } from "@opal/layouts";
import { SvgEyeOff, SvgX } from "@opal/icons";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import AgentMessage, {
AgentMessageProps,
} from "@/app/app/message/messageComponents/AgentMessage";
import { cn } from "@/lib/utils";
import { markdown } from "@opal/utils";
/** Props for the `MultiModelPanel` component (one model's response panel). */
export interface MultiModelPanelProps {
  /** Provider name for icon lookup */
  provider: string;
  /** Model name for icon lookup and display */
  modelName: string;
  /** Display-friendly model name */
  displayName: string;
  /** Whether this panel is the preferred/selected response */
  isPreferred: boolean;
  /** Whether this panel is currently hidden */
  isHidden: boolean;
  /** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
  isNonPreferredInSelection: boolean;
  /** Callback when user clicks this panel to select as preferred */
  onSelect: () => void;
  /** Callback to hide/show this panel */
  onToggleVisibility: () => void;
  /** Props to pass through to AgentMessage */
  agentMessageProps: AgentMessageProps;
}
/**
 * A single model's response panel within the multi-model view.
 *
 * Renders in two states:
 * - **Hidden** — compact header strip only (provider icon + strikethrough name + show button).
 * - **Visible** — full header plus `AgentMessage` body. Clicking anywhere on a
 *   visible non-preferred panel marks it as preferred.
 *
 * The `isNonPreferredInSelection` flag disables pointer events on the body and
 * hides the footer so the panel acts as a passive comparison surface.
 */
export default function MultiModelPanel({
  provider,
  modelName,
  displayName,
  isPreferred,
  isHidden,
  isNonPreferredInSelection,
  onSelect,
  onToggleVisibility,
  agentMessageProps,
}: MultiModelPanelProps) {
  const ProviderIcon = getProviderIcon(provider, modelName);

  // Clicking a visible, non-preferred panel promotes it to preferred.
  // Hidden and already-preferred panels ignore clicks.
  const handlePanelClick = useCallback(() => {
    if (!isHidden && !isPreferred) onSelect();
  }, [isHidden, isPreferred, onSelect]);

  // Header strip shared by both the hidden and visible states. In the hidden
  // state the title is rendered with markdown strikethrough.
  const header = (
    <div
      className={cn(
        "rounded-12",
        isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
      )}
    >
      <ContentAction
        sizePreset="main-ui"
        variant="body"
        paddingVariant="lg"
        icon={ProviderIcon}
        title={isHidden ? markdown(`~~${displayName}~~`) : displayName}
        rightChildren={
          <div className="flex items-center gap-1 px-2">
            {isPreferred && (
              <span className="text-action-link-05 shrink-0">
                <Text font="secondary-body" color="inherit" nowrap>
                  Preferred Response
                </Text>
              </span>
            )}
            {!isPreferred && (
              <Button
                prominence="tertiary"
                icon={isHidden ? SvgEyeOff : SvgX}
                size="md"
                onClick={(e) => {
                  // Keep the toggle click from also selecting the panel.
                  e.stopPropagation();
                  onToggleVisibility();
                }}
                tooltip={isHidden ? "Show response" : "Hide response"}
              />
            )}
          </div>
        }
      />
    </div>
  );

  // Hidden/collapsed panel — just the header row
  if (isHidden) {
    return header;
  }

  return (
    <div
      className={cn(
        "flex flex-col gap-3 min-w-0 rounded-16 transition-colors",
        !isPreferred && "cursor-pointer hover:bg-background-tint-02"
      )}
      onClick={handlePanelClick}
    >
      {header}
      {/* Body: pointer events disabled for non-preferred panels in selection mode */}
      <div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
        <AgentMessage
          {...agentMessageProps}
          hideFooter={isNonPreferredInSelection}
        />
      </div>
    </div>
  );
}

View File

@@ -1,372 +0,0 @@
"use client";
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { Message } from "@/app/app/interfaces";
import { LlmManager } from "@/lib/hooks";
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
import { MultiModelResponse } from "@/app/app/message/interfaces";
import { cn } from "@/lib/utils";
/** Props for `MultiModelResponseView`. */
export interface MultiModelResponseViewProps {
  /** One entry per model response; each renders as a panel. */
  responses: MultiModelResponse[];
  /** Chat state passed through to each panel's `AgentMessage`. */
  chatState: FullChatState;
  /** LLM manager passed through to each panel's `AgentMessage`. */
  llmManager: LlmManager | null;
  /** Regeneration factory forwarded to `AgentMessage`. */
  onRegenerate?: RegenerationFactory;
  /** Parent message forwarded to `AgentMessage`. */
  parentMessage?: Message | null;
  /** Sibling node ids forwarded to `AgentMessage`. */
  otherMessagesCanSwitchTo?: number[];
  /** Invoked with the selected response's nodeId when a panel is marked preferred. */
  onMessageSelection?: (nodeId: number) => void;
  /** Called whenever the set of hidden panel indices changes */
  onHiddenPanelsChange?: (hidden: Set<number>) => void;
}
// How many pixels of a non-preferred panel are visible at the viewport edge
const PEEK_W = 64;
// Uniform panel width (px) used in the selection-mode carousel
const SELECTION_PANEL_W = 400;
// Compact width (px) for hidden panels in the carousel track
const HIDDEN_PANEL_W = 220;
// Generation-mode panel widths in px (from Figma)
const GEN_PANEL_W_2 = 640; // 2 panels side-by-side
const GEN_PANEL_W_3 = 436; // 3 panels side-by-side
// Gap between panels — matches CSS gap-6 (24px)
const PANEL_GAP = 24;
// Minimum panel width (px) before horizontal scroll kicks in
const MIN_PANEL_W = 300;
/**
 * Renders N model responses side-by-side with two layout modes:
 *
 * **Generation mode** — equal-width panels in a horizontally-scrollable row.
 * Panel width is determined by the number of visible (non-hidden) panels.
 *
 * **Selection mode** — activated when the user clicks a panel to mark it as
 * preferred. All panels (including hidden ones) sit in a fixed-width carousel
 * track. A CSS `translateX` transform slides the track so the preferred panel
 * is centered in the viewport; the other panels peek in from the edges through
 * a mask gradient. Non-preferred visible panels are height-capped to the
 * preferred panel's measured height, dimmed at 50% opacity, and receive a
 * bottom fade-out overlay.
 *
 * Hidden panels render as a compact header-only strip at `HIDDEN_PANEL_W` in
 * both modes and are excluded from layout width calculations.
 */
export default function MultiModelResponseView({
  responses,
  chatState,
  llmManager,
  onRegenerate,
  parentMessage,
  otherMessagesCanSwitchTo,
  onMessageSelection,
  onHiddenPanelsChange,
}: MultiModelResponseViewProps) {
  const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
  const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
  // Controls animation: false = panels at start position, true = panels at peek position
  const [selectionEntered, setSelectionEntered] = useState(false);

  // Measures the overflow-hidden carousel container for responsive preferred-panel sizing.
  const [trackContainerW, setTrackContainerW] = useState(0);
  const roRef = useRef<ResizeObserver | null>(null);
  // Callback ref: tears down and re-creates the ResizeObserver whenever the
  // container element is attached/detached.
  const trackContainerRef = useCallback((el: HTMLDivElement | null) => {
    if (roRef.current) {
      roRef.current.disconnect();
      roRef.current = null;
    }
    if (!el) return;
    const ro = new ResizeObserver(([entry]) => {
      setTrackContainerW(entry?.contentRect.width ?? 0);
    });
    ro.observe(el);
    // Seed the width immediately so the first layout pass doesn't wait for
    // the observer's initial callback.
    setTrackContainerW(el.offsetWidth);
    roRef.current = ro;
  }, []);

  // Measures the preferred panel's height to cap non-preferred panels in selection mode.
  const [preferredPanelHeight, setPreferredPanelHeight] = useState<
    number | null
  >(null);
  const preferredRoRef = useRef<ResizeObserver | null>(null);
  // Tracks which non-preferred panels overflow the preferred height cap
  const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
    new Set()
  );
  // Callback ref for the preferred panel; mirrors trackContainerRef but
  // observes height instead of width.
  const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
    if (preferredRoRef.current) {
      preferredRoRef.current.disconnect();
      preferredRoRef.current = null;
    }
    if (!el) {
      setPreferredPanelHeight(null);
      return;
    }
    const ro = new ResizeObserver(([entry]) => {
      setPreferredPanelHeight(entry?.contentRect.height ?? 0);
    });
    ro.observe(el);
    setPreferredPanelHeight(el.offsetHeight);
    preferredRoRef.current = ro;
  }, []);

  const isGenerating = useMemo(
    () => responses.some((r) => r.isGenerating),
    [responses]
  );

  // Non-hidden responses — used for layout width decisions and selection-mode gating
  const visibleResponses = useMemo(
    () => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
    [responses, hiddenPanels]
  );

  const toggleVisibility = useCallback(
    (modelIndex: number) => {
      // NOTE(review): onHiddenPanelsChange is invoked inside the setState
      // updater, which React may run more than once (e.g. under StrictMode).
      // Consider moving the notification into an effect — confirm.
      setHiddenPanels((prev) => {
        const next = new Set(prev);
        if (next.has(modelIndex)) {
          next.delete(modelIndex);
        } else {
          // Don't hide the last visible panel
          const visibleCount = responses.length - next.size;
          if (visibleCount <= 1) return prev;
          next.add(modelIndex);
        }
        onHiddenPanelsChange?.(next);
        return next;
      });
    },
    [responses.length, onHiddenPanelsChange]
  );

  // Marks a panel as preferred and reports the selection upstream.
  // Disabled while any response is still streaming.
  const handleSelectPreferred = useCallback(
    (modelIndex: number) => {
      if (isGenerating) return;
      setPreferredIndex(modelIndex);
      const response = responses.find((r) => r.modelIndex === modelIndex);
      if (!response) return;
      if (onMessageSelection) {
        onMessageSelection(response.nodeId);
      }
    },
    [isGenerating, responses, onMessageSelection]
  );

  // Clear preferred selection when generation starts
  useEffect(() => {
    if (isGenerating) {
      setPreferredIndex(null);
    }
  }, [isGenerating]);

  // Find preferred panel position — used for both the selection guard and carousel layout
  const preferredIdx = responses.findIndex(
    (r) => r.modelIndex === preferredIndex
  );

  // Selection mode when preferred is set, found in responses, not generating, and at least 2 visible panels
  const showSelectionMode =
    preferredIndex !== null &&
    preferredIdx !== -1 &&
    !isGenerating &&
    visibleResponses.length > 1;

  // Trigger the slide-out animation one frame after entering selection mode
  useEffect(() => {
    if (!showSelectionMode) {
      setSelectionEntered(false);
      return;
    }
    const raf = requestAnimationFrame(() => setSelectionEntered(true));
    return () => cancelAnimationFrame(raf);
  }, [showSelectionMode]);

  // Build panel props — isHidden reflects actual hidden state
  const buildPanelProps = useCallback(
    (response: MultiModelResponse, isNonPreferred: boolean) => ({
      provider: response.provider,
      modelName: response.modelName,
      displayName: response.displayName,
      isPreferred: preferredIndex === response.modelIndex,
      isHidden: hiddenPanels.has(response.modelIndex),
      isNonPreferredInSelection: isNonPreferred,
      onSelect: () => handleSelectPreferred(response.modelIndex),
      onToggleVisibility: () => toggleVisibility(response.modelIndex),
      agentMessageProps: {
        rawPackets: response.packets,
        packetCount: response.packetCount,
        chatState,
        nodeId: response.nodeId,
        messageId: response.messageId,
        currentFeedback: response.currentFeedback,
        llmManager,
        otherMessagesCanSwitchTo,
        onMessageSelection,
        onRegenerate,
        parentMessage,
      },
    }),
    [
      preferredIndex,
      hiddenPanels,
      handleSelectPreferred,
      toggleVisibility,
      chatState,
      llmManager,
      otherMessagesCanSwitchTo,
      onMessageSelection,
      onRegenerate,
      parentMessage,
    ]
  );

  if (showSelectionMode) {
    // ── Selection Layout (transform-based carousel) ──
    //
    // All panels (including hidden) sit in the track at their original A/B/C positions.
    // Hidden panels use HIDDEN_PANEL_W; non-preferred use SELECTION_PANEL_W;
    // preferred uses dynamicPrefW (up to GEN_PANEL_W_2).
    const n = responses.length;
    // Preferred panel shrinks to leave room for the peeking neighbors,
    // but never grows beyond GEN_PANEL_W_2.
    const dynamicPrefW =
      trackContainerW > 0
        ? Math.min(trackContainerW - 2 * (PEEK_W + PANEL_GAP), GEN_PANEL_W_2)
        : GEN_PANEL_W_2;
    const selectionWidths = responses.map((r, i) => {
      if (hiddenPanels.has(r.modelIndex)) return HIDDEN_PANEL_W;
      if (i === preferredIdx) return dynamicPrefW;
      return SELECTION_PANEL_W;
    });
    // Prefix-sum of widths+gaps → each panel's left edge within the track.
    const panelLeftEdges = selectionWidths.reduce<number[]>((acc, w, i) => {
      acc.push(i === 0 ? 0 : acc[i - 1]! + selectionWidths[i - 1]! + PANEL_GAP);
      return acc;
    }, []);
    const preferredCenterInTrack =
      panelLeftEdges[preferredIdx]! + selectionWidths[preferredIdx]! / 2;
    // Start position: hidden panels at HIDDEN_PANEL_W, visible at SELECTION_PANEL_W
    const uniformTrackW =
      responses.reduce(
        (sum, r) =>
          sum +
          (hiddenPanels.has(r.modelIndex) ? HIDDEN_PANEL_W : SELECTION_PANEL_W),
        0
      ) +
      (n - 1) * PANEL_GAP;
    // Before the entry animation the track is centered as a whole; after,
    // the track slides so the preferred panel's center sits at the viewport center.
    const trackTransform = selectionEntered
      ? `translateX(${trackContainerW / 2 - preferredCenterInTrack}px)`
      : `translateX(${(trackContainerW - uniformTrackW) / 2}px)`;

    return (
      <div
        ref={trackContainerRef}
        className="w-full overflow-hidden"
        style={{
          maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
          WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
        }}
      >
        <div
          className="flex items-start"
          style={{
            gap: `${PANEL_GAP}px`,
            transition: selectionEntered
              ? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
              : "none",
            transform: trackTransform,
          }}
        >
          {responses.map((r, i) => {
            const isHidden = hiddenPanels.has(r.modelIndex);
            const isPref = r.modelIndex === preferredIndex;
            const isNonPref = !isHidden && !isPref;
            const finalW = selectionWidths[i]!;
            const startW = isHidden ? HIDDEN_PANEL_W : SELECTION_PANEL_W;
            const capped = isNonPref && preferredPanelHeight != null;
            const overflows = capped && overflowingPanels.has(r.modelIndex);
            return (
              <div
                key={r.modelIndex}
                ref={(el) => {
                  if (isPref) preferredPanelRef(el);
                  // NOTE(review): calling setState from a ref callback relies
                  // on the `doesOverflow === had` bail-out below to avoid a
                  // render loop — confirm this holds under concurrent rendering.
                  if (capped && el) {
                    const doesOverflow = el.scrollHeight > el.clientHeight;
                    setOverflowingPanels((prev) => {
                      const had = prev.has(r.modelIndex);
                      if (doesOverflow === had) return prev;
                      const next = new Set(prev);
                      if (doesOverflow) next.add(r.modelIndex);
                      else next.delete(r.modelIndex);
                      return next;
                    });
                  }
                }}
                style={{
                  width: `${selectionEntered ? finalW : startW}px`,
                  flexShrink: 0,
                  transition: selectionEntered
                    ? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
                    : "none",
                  maxHeight: capped ? preferredPanelHeight : undefined,
                  overflow: capped ? "hidden" : undefined,
                  position: capped ? "relative" : undefined,
                }}
              >
                <div className={cn(isNonPref && "opacity-50")}>
                  <MultiModelPanel {...buildPanelProps(r, isNonPref)} />
                </div>
                {overflows && (
                  <div
                    className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
                    style={{
                      background:
                        "linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
                    }}
                  />
                )}
              </div>
            );
          })}
        </div>
      </div>
    );
  }

  // ── Generation Layout (equal panels side-by-side) ──
  // Panel width based on number of visible (non-hidden) panels.
  const panelWidth =
    visibleResponses.length <= 2 ? GEN_PANEL_W_2 : GEN_PANEL_W_3;

  return (
    <div className="overflow-x-auto">
      <div className="flex gap-6 items-start w-fit mx-auto">
        {responses.map((r) => {
          const isHidden = hiddenPanels.has(r.modelIndex);
          return (
            <div
              key={r.modelIndex}
              style={
                isHidden
                  ? {
                      width: HIDDEN_PANEL_W,
                      minWidth: HIDDEN_PANEL_W,
                      maxWidth: HIDDEN_PANEL_W,
                      flexShrink: 0,
                      overflow: "hidden" as const,
                    }
                  : { width: panelWidth, minWidth: MIN_PANEL_W }
              }
            >
              <MultiModelPanel {...buildPanelProps(r, false)} />
            </div>
          );
        })}
      </div>
    </div>
  );
}

View File

@@ -1,16 +0,0 @@
import { Packet } from "@/app/app/services/streamingModels";
import { FeedbackType } from "@/app/app/interfaces";
/** One model's streamed response within a multi-model comparison turn. */
export interface MultiModelResponse {
  /** Position of this model in the multi-model request; used as a stable panel key. */
  modelIndex: number;
  /** Provider name, used for icon lookup. */
  provider: string;
  /** Raw model name, used for icon lookup. */
  modelName: string;
  /** Display-friendly model name shown in the panel header. */
  displayName: string;
  /** Streaming packets rendered by the panel's `AgentMessage`. */
  packets: Packet[];
  /** Packet count, passed to `AgentMessage` alongside `packets`. */
  packetCount: number;
  /** Message-tree node id; reported upstream when this response is selected. */
  nodeId: number;
  /** Backend message id, when available. */
  messageId?: number;
  /** Feedback already recorded for this response, if any. */
  currentFeedback?: FeedbackType | null;
  /** True while this model is still streaming its response. */
  isGenerating?: boolean;
}

View File

@@ -49,8 +49,6 @@ export interface AgentMessageProps {
parentMessage?: Message | null;
// Duration in seconds for processing this message (agent messages only)
processingDurationSeconds?: number;
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
hideFooter?: boolean;
}
// TODO: Consider more robust comparisons:
@@ -78,8 +76,7 @@ function arePropsEqual(
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
prev.llmManager?.isLoadingProviders ===
next.llmManager?.isLoadingProviders &&
prev.processingDurationSeconds === next.processingDurationSeconds &&
prev.hideFooter === next.hideFooter
prev.processingDurationSeconds === next.processingDurationSeconds
// Skip: chatState.regenerate, chatState.setPresentingDocument,
// most of llmManager, onMessageSelection (function/object props)
);
@@ -98,7 +95,6 @@ const AgentMessage = React.memo(function AgentMessage({
onRegenerate,
parentMessage,
processingDurationSeconds,
hideFooter,
}: AgentMessageProps) {
const markdownRef = useRef<HTMLDivElement>(null);
const finalAnswerRef = useRef<HTMLDivElement>(null);
@@ -330,7 +326,7 @@ const AgentMessage = React.memo(function AgentMessage({
</div>
{/* Feedback buttons - only show when streaming and rendering complete */}
{isComplete && !hideFooter && (
{isComplete && (
<MessageToolbar
nodeId={nodeId}
messageId={messageId}

View File

@@ -24,6 +24,7 @@ import {
} from "@/app/craft/onboarding/constants";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
import { buildOnboardingInitialValues as buildInitialValues } from "@/sections/modals/llmConfig/utils";
import { testApiKeyHelper } from "@/sections/modals/llmConfig/svc";
import OnboardingInfoPages from "@/app/craft/onboarding/components/OnboardingInfoPages";
import OnboardingUserInfo from "@/app/craft/onboarding/components/OnboardingUserInfo";
@@ -220,8 +221,10 @@ export default function BuildOnboardingModal({
setConnectionStatus("testing");
setErrorMessage("");
const baseValues = buildInitialValues();
const providerName = `build-mode-${currentProviderConfig.providerName}`;
const payload = {
...baseValues,
name: providerName,
provider: currentProviderConfig.providerName,
api_key: apiKey,

View File

@@ -133,7 +133,7 @@ async function createFederatedConnector(
async function updateFederatedConnector(
id: number,
credentials: CredentialForm | null,
credentials: CredentialForm,
config?: ConfigForm
): Promise<{ success: boolean; message: string }> {
try {
@@ -143,7 +143,7 @@ async function updateFederatedConnector(
"Content-Type": "application/json",
},
body: JSON.stringify({
credentials: credentials ?? undefined,
credentials,
config: config || {},
}),
});
@@ -201,9 +201,7 @@ export function FederatedConnectorForm({
const isEditMode = connectorId !== undefined;
const [formState, setFormState] = useState<FormState>({
// In edit mode, don't populate credentials with masked values from the API.
// Masked values (e.g. "••••••••••••") would be saved back and corrupt the real credentials.
credentials: isEditMode ? {} : preloadedConnectorData?.credentials || {},
credentials: preloadedConnectorData?.credentials || {},
config: preloadedConnectorData?.config || {},
schema: preloadedCredentialSchema?.credentials || null,
configurationSchema: null,
@@ -211,7 +209,6 @@ export function FederatedConnectorForm({
configurationSchemaError: null,
connectorError: null,
});
const [credentialsModified, setCredentialsModified] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const [submitMessage, setSubmitMessage] = useState<string | null>(null);
const [submitSuccess, setSubmitSuccess] = useState<boolean | null>(null);
@@ -336,7 +333,6 @@ export function FederatedConnectorForm({
}
const handleCredentialChange = (key: string, value: string) => {
setCredentialsModified(true);
setFormState((prev) => ({
...prev,
credentials: {
@@ -358,11 +354,6 @@ export function FederatedConnectorForm({
const handleValidateCredentials = async () => {
if (!formState.schema) return;
if (isEditMode && !credentialsModified) {
setSubmitMessage("Enter new credential values before validating.");
setSubmitSuccess(false);
return;
}
setIsValidating(true);
setSubmitMessage(null);
@@ -420,10 +411,8 @@ export function FederatedConnectorForm({
setSubmitSuccess(null);
try {
const shouldValidateCredentials = !isEditMode || credentialsModified;
// Validate required fields (skip for credentials in edit mode when unchanged)
if (formState.schema && shouldValidateCredentials) {
// Validate required fields
if (formState.schema) {
const missingRequired = Object.entries(formState.schema)
.filter(
([key, field]) => field.required && !formState.credentials[key]
@@ -453,20 +442,16 @@ export function FederatedConnectorForm({
}
setConfigValidationErrors({});
// Validate credentials before creating/updating (skip in edit mode when unchanged)
if (shouldValidateCredentials) {
const validation = await validateCredentials(
connector,
formState.credentials
);
if (!validation.success) {
setSubmitMessage(
`Credential validation failed: ${validation.message}`
);
setSubmitSuccess(false);
setIsSubmitting(false);
return;
}
// Validate credentials before creating/updating
const validation = await validateCredentials(
connector,
formState.credentials
);
if (!validation.success) {
setSubmitMessage(`Credential validation failed: ${validation.message}`);
setSubmitSuccess(false);
setIsSubmitting(false);
return;
}
// Create or update the connector
@@ -474,7 +459,7 @@ export function FederatedConnectorForm({
isEditMode && connectorId
? await updateFederatedConnector(
connectorId,
credentialsModified ? formState.credentials : null,
formState.credentials,
formState.config
)
: await createFederatedConnector(
@@ -553,16 +538,14 @@ export function FederatedConnectorForm({
id={fieldKey}
type={fieldSpec.secret ? "password" : "text"}
placeholder={
isEditMode && !credentialsModified
? "•••••••• (leave blank to keep current value)"
: fieldSpec.example
? String(fieldSpec.example)
: fieldSpec.description
fieldSpec.example
? String(fieldSpec.example)
: fieldSpec.description
}
value={formState.credentials[fieldKey] || ""}
onChange={(e) => handleCredentialChange(fieldKey, e.target.value)}
className="w-96"
required={!isEditMode && fieldSpec.required}
required={fieldSpec.required}
/>
</div>
))}

View File

@@ -1,10 +1,25 @@
"use client";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import React, { createContext, useContext, useCallback } from "react";
import {
WellKnownLLMProviderDescriptor,
LLMProviderDescriptor,
} from "@/interfaces/llm";
import React, {
createContext,
useContext,
useState,
useEffect,
useCallback,
} from "react";
import { useUser } from "@/providers/UserProvider";
import { useLLMProviders } from "@/hooks/useLLMProviders";
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llmConfig/svc";
interface ProviderContextType {
shouldShowConfigurationNeeded: boolean;
providerOptions: WellKnownLLMProviderDescriptor[];
refreshProviderInfo: () => Promise<void>;
// Expose configured provider instances for components that need it (e.g., onboarding)
llmProviders: LLMProviderDescriptor[] | undefined;
isLoadingProviders: boolean;
hasProviders: boolean;
@@ -14,26 +29,79 @@ const ProviderContext = createContext<ProviderContextType | undefined>(
undefined
);
const DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY = "defaultLlmProviderTestComplete";

/**
 * Returns whether the default-LLM-provider test has already been recorded as
 * complete. During SSR (no `window`) this reports `true` so the test is not
 * triggered server-side.
 */
function checkDefaultLLMProviderTestComplete(): boolean {
  if (typeof window === "undefined") {
    return true;
  }
  const stored = localStorage.getItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY);
  return stored === "true";
}

/**
 * Records in localStorage that the default-provider test passed.
 * No-op during SSR where localStorage is unavailable.
 */
function setDefaultLLMProviderTestComplete(): void {
  if (typeof window !== "undefined") {
    localStorage.setItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY, "true");
  }
}
export function ProviderContextProvider({
children,
}: {
children: React.ReactNode;
}) {
const { user } = useUser();
// Use SWR hooks instead of raw fetch
const {
llmProviders,
isLoading: isLoadingProviders,
refetch: refetchProviders,
} = useLLMProviders();
const { llmProviderOptions: providerOptions, refetch: refetchOptions } =
useLLMProviderOptions();
const [defaultCheckSuccessful, setDefaultCheckSuccessful] =
useState<boolean>(true);
// Test the default provider - only runs if test hasn't passed yet
const testDefaultProvider = useCallback(async () => {
const shouldCheck =
!checkDefaultLLMProviderTestComplete() &&
(!user || user.role === "admin");
if (shouldCheck) {
const success = await testDefaultProviderSvc();
setDefaultCheckSuccessful(success);
if (success) {
setDefaultLLMProviderTestComplete();
}
}
}, [user]);
// Test default provider on mount
useEffect(() => {
testDefaultProvider();
}, [testDefaultProvider]);
const hasProviders = (llmProviders?.length ?? 0) > 0;
const validProviderExists = hasProviders && defaultCheckSuccessful;
const shouldShowConfigurationNeeded =
!validProviderExists && (providerOptions?.length ?? 0) > 0;
const refreshProviderInfo = useCallback(async () => {
await refetchProviders();
}, [refetchProviders]);
// Refetch provider lists and re-test default provider if needed
await Promise.all([
refetchProviders(),
refetchOptions(),
testDefaultProvider(),
]);
}, [refetchProviders, refetchOptions, testDefaultProvider]);
return (
<ProviderContext.Provider
value={{
shouldShowConfigurationNeeded,
providerOptions: providerOptions ?? [],
refreshProviderInfo,
llmProviders,
isLoadingProviders,

View File

@@ -17,6 +17,7 @@ const mockProviderStatus = {
llmProviders: [] as unknown[],
isLoadingProviders: false,
hasProviders: false,
providerOptions: [],
refreshProviderInfo: jest.fn(),
};
@@ -70,6 +71,7 @@ describe("useShowOnboarding", () => {
mockProviderStatus.llmProviders = [];
mockProviderStatus.isLoadingProviders = false;
mockProviderStatus.hasProviders = false;
mockProviderStatus.providerOptions = [];
});
it("returns showOnboarding=false while providers are loading", () => {
@@ -196,6 +198,7 @@ describe("useShowOnboarding", () => {
OnboardingStep.Welcome
);
expect(result.current.onboardingActions).toBeDefined();
expect(result.current.llmDescriptors).toEqual([]);
});
describe("localStorage persistence", () => {

View File

@@ -5,7 +5,6 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
import {
LLMProviderDescriptor,
LLMProviderName,
LLMProviderResponse,
LLMProviderView,
WellKnownLLMProviderDescriptor,
@@ -139,12 +138,12 @@ export function useAdminLLMProviders() {
* Used inside individual provider modals to pre-populate model lists
* before the user has entered credentials.
*
* @param providerName - The provider's API endpoint name (e.g. "openai", "anthropic").
* @param providerEndpoint - The provider's API endpoint name (e.g. "openai", "anthropic").
* Pass `null` to suppress the request.
*/
export function useWellKnownLLMProvider(providerName: LLMProviderName) {
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
providerName ? SWR_KEYS.wellKnownLlmProvider(providerName) : null,
providerEndpoint ? SWR_KEYS.wellKnownLlmProvider(providerEndpoint) : null,
errorHandlingFetcher,
{
revalidateOnFocus: false,

View File

@@ -1,192 +0,0 @@
"use client";
import { useState, useCallback, useEffect, useMemo } from "react";
import {
MAX_MODELS,
SelectedModel,
} from "@/refresh-components/popovers/ModelSelector";
import { LLMOverride } from "@/app/app/services/lib";
import { LlmManager } from "@/lib/hooks";
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
/**
 * Value returned by the `useMultiModelChat` hook: the current multi-model
 * selection plus the actions that mutate it.
 */
export interface UseMultiModelChatReturn {
  /** Currently selected models for multi-model comparison. */
  selectedModels: SelectedModel[];
  /** Whether multi-model mode is active (>1 model selected). */
  isMultiModelActive: boolean;
  /** Add a model to the selection. */
  addModel: (model: SelectedModel) => void;
  /** Remove a model by index. */
  removeModel: (index: number) => void;
  /** Replace a model at a specific index with a new one. */
  replaceModel: (index: number, model: SelectedModel) => void;
  /** Clear all selected models. */
  clearModels: () => void;
  /** Build the LLMOverride[] array from selectedModels. */
  buildLlmOverrides: () => LLMOverride[];
  /**
   * Restore multi-model selection from model version strings (e.g. from chat history).
   * Matches against available llmOptions to reconstruct full SelectedModel objects.
   */
  restoreFromModelNames: (modelNames: string[]) => void;
  /**
   * Switch to a single model by name (after user picks a preferred response).
   * Matches against llmOptions to find the full SelectedModel.
   */
  selectSingleModel: (modelName: string) => void;
}
/**
 * Manages the model selection used for multi-model chat comparison.
 *
 * Keeps a list of up to MAX_MODELS selected models, seeds it with the
 * manager's current model once provider options load, and exposes mutators
 * plus helpers to build LLM overrides and restore a selection from model
 * name strings (e.g. chat history).
 */
export default function useMultiModelChat(
  llmManager: LlmManager
): UseMultiModelChatReturn {
  const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
  const [defaultInitialized, setDefaultInitialized] = useState(false);

  // Flatten the configured providers into selectable options once loaded.
  const llmOptions = useMemo(
    () =>
      llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
    [llmManager.llmProviders]
  );

  // Seed the selection with the manager's current model, exactly once,
  // as soon as both the options and the current model are available.
  useEffect(() => {
    if (defaultInitialized || llmOptions.length === 0) return;

    const { currentLlm } = llmManager;
    // Current model not resolved yet — wait for a later render.
    if (!currentLlm.modelName) return;

    const defaultOption = llmOptions.find(
      (candidate) =>
        candidate.provider === currentLlm.provider &&
        candidate.modelName === currentLlm.modelName
    );
    if (!defaultOption) return;

    setSelectedModels([
      {
        name: defaultOption.name,
        provider: defaultOption.provider,
        modelName: defaultOption.modelName,
        displayName: defaultOption.displayName,
      },
    ]);
    setDefaultInitialized(true);
  }, [llmOptions, llmManager.currentLlm, defaultInitialized]);

  const isMultiModelActive = selectedModels.length > 1;

  // Append a model unless the cap is reached or it is already selected.
  const addModel = useCallback((model: SelectedModel) => {
    setSelectedModels((current) => {
      const alreadyPicked = current.some(
        (entry) =>
          entry.provider === model.provider &&
          entry.modelName === model.modelName
      );
      if (current.length >= MAX_MODELS || alreadyPicked) {
        return current;
      }
      return [...current, model];
    });
  }, []);

  const removeModel = useCallback((index: number) => {
    setSelectedModels((current) => current.filter((_, i) => i !== index));
  }, []);

  // Swap the model at `index`, unless the replacement is already selected
  // in a different slot.
  const replaceModel = useCallback((index: number, model: SelectedModel) => {
    setSelectedModels((current) => {
      const duplicateElsewhere = current.some(
        (entry, i) =>
          i !== index &&
          entry.provider === model.provider &&
          entry.modelName === model.modelName
      );
      if (duplicateElsewhere) {
        return current;
      }
      const next = [...current];
      next[index] = model;
      return next;
    });
  }, []);

  const clearModels = useCallback(() => {
    setSelectedModels([]);
  }, []);

  // Rebuild a multi-model selection from model name strings; each name may
  // match an option's raw modelName, displayName, or internal name.
  const restoreFromModelNames = useCallback(
    (modelNames: string[]) => {
      if (modelNames.length < 2 || llmOptions.length === 0) return;

      const recovered: SelectedModel[] = [];
      for (const candidateName of modelNames) {
        const found = llmOptions.find(
          (opt) =>
            opt.modelName === candidateName ||
            opt.displayName === candidateName ||
            opt.name === candidateName
        );
        if (found) {
          recovered.push({
            name: found.name,
            provider: found.provider,
            modelName: found.modelName,
            displayName: found.displayName,
          });
        }
      }

      // Only treat it as a multi-model restore when 2+ models matched.
      if (recovered.length >= 2) {
        setSelectedModels(recovered.slice(0, MAX_MODELS));
        setDefaultInitialized(true);
      }
    },
    [llmOptions]
  );

  // Collapse the selection to a single model identified by name.
  const selectSingleModel = useCallback(
    (modelName: string) => {
      if (llmOptions.length === 0) return;

      const found = llmOptions.find(
        (opt) =>
          opt.modelName === modelName ||
          opt.displayName === modelName ||
          opt.name === modelName
      );
      if (!found) return;

      setSelectedModels([
        {
          name: found.name,
          provider: found.provider,
          modelName: found.modelName,
          displayName: found.displayName,
        },
      ]);
    },
    [llmOptions]
  );

  // Translate the selection into the LLMOverride wire format.
  const buildLlmOverrides = useCallback((): LLMOverride[] => {
    return selectedModels.map((entry) => ({
      model_provider: entry.provider,
      model_version: entry.modelName,
      display_name: entry.displayName,
    }));
  }, [selectedModels]);

  return {
    selectedModels,
    isMultiModelActive,
    addModel,
    removeModel,
    replaceModel,
    clearModels,
    buildLlmOverrides,
    restoreFromModelNames,
    selectSingleModel,
  };
}

View File

@@ -9,6 +9,7 @@ import {
OnboardingState,
OnboardingStep,
} from "@/interfaces/onboarding";
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
import { updateUserPersonalization } from "@/lib/userSettings";
import { useUser } from "@/providers/UserProvider";
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
@@ -21,6 +22,7 @@ function getOnboardingCompletedKey(userId: string): string {
function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
state: OnboardingState;
llmDescriptors: WellKnownLLMProviderDescriptor[];
actions: OnboardingActions;
isLoading: boolean;
hasProviders: boolean;
@@ -33,6 +35,7 @@ function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
llmProviders,
isLoadingProviders,
hasProviders: hasLlmProviders,
providerOptions,
refreshProviderInfo,
} = useProviderStatus();
@@ -40,6 +43,7 @@ function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
const { refetch: refreshPersonaProviders } = useLLMProviders(liveAgent?.id);
const userName = user?.personalization?.name;
const llmDescriptors = providerOptions;
const nameUpdateTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
null
@@ -231,6 +235,7 @@ function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
return {
state,
llmDescriptors,
actions: {
nextStep,
prevStep,
@@ -275,6 +280,7 @@ export function useShowOnboarding({
const {
state: onboardingState,
actions: onboardingActions,
llmDescriptors,
isLoading: isLoadingOnboarding,
hasProviders: hasAnyProvider,
} = useOnboardingState(liveAgent);
@@ -344,6 +350,7 @@ export function useShowOnboarding({
onboardingDismissed,
onboardingState,
onboardingActions,
llmDescriptors,
isLoadingOnboarding,
hideOnboarding,
finishOnboarding,

View File

@@ -15,7 +15,6 @@ export enum LLMProviderName {
LITELLM = "litellm",
LITELLM_PROXY = "litellm_proxy",
BIFROST = "bifrost",
OPENAI_COMPATIBLE = "openai_compatible",
CUSTOM = "custom",
}
@@ -123,12 +122,16 @@ export interface LLMProviderFormProps {
variant?: LLMModalVariant;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
open?: boolean;
onOpenChange?: (open: boolean) => void;
/** Called after successful provider creation/update. */
onSuccess?: () => void | Promise<void>;
/** The current default model name for this provider (from the global default). */
defaultModelName?: string;
// Onboarding-specific (only when variant === "onboarding")
onboardingState?: OnboardingState;
onboardingActions?: OnboardingActions;
llmDescriptor?: WellKnownLLMProviderDescriptor;
}
// Param types for model fetching functions - use snake_case to match API structure
@@ -179,21 +182,6 @@ export interface BifrostModelResponse {
supports_reasoning: boolean;
}
/** Parameters for fetching the model list from an OpenAI-compatible endpoint. */
export interface OpenAICompatibleFetchParams {
  /** Base URL of the OpenAI-compatible API. */
  api_base?: string;
  /** API key, if the endpoint requires authentication. */
  api_key?: string;
  /** Provider name to associate with the fetched models. */
  provider_name?: string;
  /** Abort signal to cancel an in-flight request. */
  signal?: AbortSignal;
}

/** One model entry as returned by an OpenAI-compatible endpoint. */
export interface OpenAICompatibleModelResponse {
  name: string;
  display_name: string;
  // null when the endpoint does not report a context window size —
  // presumably; confirm against the backend serializer.
  max_input_tokens: number | null;
  supports_image_input: boolean;
  supports_reasoning: boolean;
}
export interface VertexAIFetchParams {
model_configurations?: ModelConfiguration[];
}
@@ -212,6 +200,5 @@ export type FetchModelsParams =
| OpenRouterFetchParams
| LiteLLMProxyFetchParams
| BifrostFetchParams
| OpenAICompatibleFetchParams
| VertexAIFetchParams
| LMStudioFetchParams;

View File

@@ -1,9 +1,8 @@
"use client";
import type { RichStr, WithoutStyles } from "@opal/types";
import type { RichStr } from "@opal/types";
import { resolveStr } from "@opal/components/text/InlineMarkdown";
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
import { useField, useFormikContext } from "formik";
import { Section } from "@/layouts/general-layouts";
@@ -230,27 +229,9 @@ function ErrorTextLayout({ children, type = "error" }: ErrorTextLayoutProps) {
);
}
/**
 * FieldSeparator - Horizontal rule with inline padding, visually separating
 * groups of fields.
 */
function FieldSeparator() {
  const separatorClasses = "p-2";
  return <Separator noPadding className={separatorClasses} />;
}
/**
 * FieldPadder - Wraps a field in standard horizontal + vertical padding (`p-2 w-full`).
 */
type FieldPadderProps = WithoutStyles<React.HTMLAttributes<HTMLDivElement>>;

// className is written after the spread so the standard padding always wins
// over anything forwarded in props.
function FieldPadder(props: FieldPadderProps) {
  return <div {...props} className="p-2 w-full" />;
}
export {
VerticalInputLayout as Vertical,
HorizontalInputLayout as Horizontal,
ErrorLayout as Error,
ErrorTextLayout,
FieldSeparator,
FieldPadder,
type FieldPadderProps,
};

View File

@@ -7,7 +7,6 @@ import {
SvgOllama,
SvgAws,
SvgOpenrouter,
SvgPlug,
SvgServer,
SvgAzure,
SvgGemini,
@@ -28,7 +27,6 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
[LLMProviderName.BIFROST]: SvgBifrost,
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
// fallback
[LLMProviderName.CUSTOM]: SvgServer,
@@ -46,7 +44,6 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
[LLMProviderName.OPENROUTER]: "OpenRouter",
[LLMProviderName.LM_STUDIO]: "LM Studio",
[LLMProviderName.BIFROST]: "Bifrost",
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
// fallback
[LLMProviderName.CUSTOM]: "Custom Models",
@@ -64,7 +61,6 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
[LLMProviderName.OPENROUTER]: "OpenRouter",
[LLMProviderName.LM_STUDIO]: "LM Studio",
[LLMProviderName.BIFROST]: "Bifrost",
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
// fallback
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",

View File

@@ -80,11 +80,6 @@ export const SWR_KEYS = {
// ── Users ─────────────────────────────────────────────────────────────────
acceptedUsers: "/api/manage/users/accepted/all",
invitedUsers: "/api/manage/users/invited",
// Curator-accessible listing of all users (and optionally service-account
// entries when `?include_api_keys=true`). Used by group create/edit pages so
// global curators — who cannot hit the admin-only `/accepted/all` and
// `/invited` endpoints — can still load the member picker.
groupMemberCandidates: "/api/manage/users?include_api_keys=true",
pendingTenantUsers: "/api/tenants/users/pending",
userCounts: "/api/manage/users/counts",

View File

@@ -89,6 +89,18 @@ export const KeyWideLayout: Story = {
},
};
// Story rendering KeyValueInput with the `disabled` prop set, showing the
// component's inert state with a single non-editable row.
export const Disabled: Story = {
  render: () => (
    <KeyValueInput
      keyTitle="Key"
      valueTitle="Value"
      items={[{ key: "LOCKED", value: "cannot-edit" }]}
      onChange={() => {}}
      disabled
    />
  ),
};
export const EmptyLineMode: Story = {
render: function EmptyStory() {
const [items, setItems] = React.useState<KeyValue[]>([]);

View File

@@ -68,13 +68,21 @@
* ```
*/
import React, { useCallback, useEffect, useMemo, useRef } from "react";
import React, {
useCallback,
useContext,
useEffect,
useMemo,
useId,
useRef,
} from "react";
import { cn } from "@/lib/utils";
import InputTypeIn from "./InputTypeIn";
import { Button, EmptyMessageCard } from "@opal/components";
import type { WithoutStyles } from "@opal/types";
import Text from "@/refresh-components/texts/Text";
import { ErrorTextLayout } from "@/layouts/input-layouts";
import { FieldContext } from "../form/FieldContext";
import { FieldMessage } from "../messages/FieldMessage";
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
export type KeyValue = { key: string; value: string };
@@ -99,50 +107,82 @@ const GRID_COLS = {
/** Props for a single key/value row rendered inside KeyValueInput. */
interface KeyValueInputItemProps {
  /** The key/value pair this row renders. */
  item: KeyValue;
  /** Called with the updated pair when either input changes. */
  onChange: (next: KeyValue) => void;
  /** Disables both inputs and the remove button. */
  disabled?: boolean;
  /** Invoked when the remove button is clicked. */
  onRemove: () => void;
  keyPlaceholder?: string;
  valuePlaceholder?: string;
  /** Validation errors for this row's key and/or value, if any. */
  error?: KeyValueError;
  /** Whether removal is allowed at all (e.g. 'fixed-line' mode keeps one row). */
  canRemove: boolean;
  /** Zero-based row index; used in aria-labels and error element ids. */
  index: number;
  /** Base id used to build unique error-message element ids for this row. */
  fieldId: string;
}
function KeyValueInputItem({
item,
onChange,
disabled,
onRemove,
keyPlaceholder,
valuePlaceholder,
error,
canRemove,
index,
fieldId,
}: KeyValueInputItemProps) {
return (
<>
<div className="flex flex-col gap-y-0.5">
<InputTypeIn
placeholder={keyPlaceholder}
placeholder={keyPlaceholder || "Key"}
value={item.key}
onChange={(e) => onChange({ ...item, key: e.target.value })}
aria-label={`${keyPlaceholder || "Key"} ${index + 1}`}
aria-invalid={!!error?.key}
aria-describedby={
error?.key ? `${fieldId}-key-error-${index}` : undefined
}
variant={disabled ? "disabled" : undefined}
showClearButton={false}
/>
{error?.key && <ErrorTextLayout>{error.key}</ErrorTextLayout>}
{error?.key && (
<FieldMessage variant="error" className="ml-0.5">
<FieldMessage.Content
id={`${fieldId}-key-error-${index}`}
role="alert"
className="ml-0.5"
>
{error.key}
</FieldMessage.Content>
</FieldMessage>
)}
</div>
<div className="flex flex-col gap-y-0.5">
<InputTypeIn
placeholder={valuePlaceholder}
placeholder={valuePlaceholder || "Value"}
value={item.value}
onChange={(e) => onChange({ ...item, value: e.target.value })}
aria-label={`${valuePlaceholder || "Value"} ${index + 1}`}
aria-invalid={!!error?.value}
aria-describedby={
error?.value ? `${fieldId}-value-error-${index}` : undefined
}
variant={disabled ? "disabled" : undefined}
showClearButton={false}
/>
{error?.value && <ErrorTextLayout>{error.value}</ErrorTextLayout>}
{error?.value && (
<FieldMessage variant="error" className="ml-0.5">
<FieldMessage.Content
id={`${fieldId}-value-error-${index}`}
role="alert"
className="ml-0.5"
>
{error.value}
</FieldMessage.Content>
</FieldMessage>
)}
</div>
<Button
disabled={!canRemove}
disabled={disabled || !canRemove}
prominence="tertiary"
icon={SvgMinusCircle}
onClick={onRemove}
@@ -158,31 +198,46 @@ export interface KeyValueInputProps
> {
/** Title for the key column */
keyTitle?: string;
/** Title for the value column */
valueTitle?: string;
/** Placeholder for the key input */
keyPlaceholder?: string;
/** Placeholder for the value input */
valuePlaceholder?: string;
/** Array of key-value pairs */
items: KeyValue[];
/** Callback when items change */
onChange: (nextItems: KeyValue[]) => void;
/** Custom add handler */
onAdd?: () => void;
/** Custom remove handler */
onRemove?: (index: number) => void;
/** Disabled state */
disabled?: boolean;
/** Mode: 'line' allows removing all items, 'fixed-line' requires at least one item */
mode?: "line" | "fixed-line";
/** Layout: 'equal' - both inputs same width, 'key-wide' - key input is wider (60/40 split) */
layout?: "equal" | "key-wide";
/** Callback when validation state changes */
onValidationChange?: (isValid: boolean, errors: KeyValueError[]) => void;
/** Callback to handle validation errors - integrates with Formik or custom error handling. Called with error message when invalid, null when valid */
onValidationError?: (errorMessage: string | null) => void;
/** Optional custom validator for the key field. Return { isValid, message } */
onKeyValidate?: (
key: string,
index: number,
item: KeyValue,
items: KeyValue[]
) => { isValid: boolean; message?: string };
/** Optional custom validator for the value field. Return { isValid, message } */
onValueValidate?: (
value: string,
index: number,
item: KeyValue,
items: KeyValue[]
) => { isValid: boolean; message?: string };
/** Whether to validate for duplicate keys */
validateDuplicateKeys?: boolean;
/** Whether to validate for empty keys */
validateEmptyKeys?: boolean;
/** Optional name for the field (for accessibility) */
name?: string;
/** Custom label for the add button (defaults to "Add Line") */
addButtonLabel?: string;
}
@@ -190,16 +245,26 @@ export interface KeyValueInputProps
export default function KeyValueInput({
keyTitle = "Key",
valueTitle = "Value",
keyPlaceholder,
valuePlaceholder,
items = [],
onChange,
onAdd,
onRemove,
disabled = false,
mode = "line",
layout = "equal",
onValidationChange,
onValidationError,
onKeyValidate,
onValueValidate,
validateDuplicateKeys = true,
validateEmptyKeys = true,
name,
addButtonLabel = "Add Line",
...rest
}: KeyValueInputProps) {
// Try to get field context if used within FormField (safe access)
const fieldContext = useContext(FieldContext);
// Validation logic
const errors = useMemo((): KeyValueError[] => {
if (!items || items.length === 0) return [];
@@ -208,8 +273,12 @@ export default function KeyValueInput({
const keyCount = new Map<string, number[]>();
items.forEach((item, index) => {
// Validate empty keys
if (item.key.trim() === "" && item.value.trim() !== "") {
// Validate empty keys - only if value is filled (user is actively working on this row)
if (
validateEmptyKeys &&
item.key.trim() === "" &&
item.value.trim() !== ""
) {
const error = errorsList[index];
if (error) {
error.key = "Key cannot be empty";
@@ -222,22 +291,56 @@ export default function KeyValueInput({
existing.push(index);
keyCount.set(item.key, existing);
}
});
// Validate duplicate keys
keyCount.forEach((indices, key) => {
if (indices.length > 1) {
indices.forEach((index) => {
// Custom key validation
if (onKeyValidate) {
const result = onKeyValidate(item.key, index, item, items);
if (result && result.isValid === false) {
const error = errorsList[index];
if (error) {
error.key = "Duplicate key";
error.key = result.message || "Invalid key";
}
});
}
}
// Custom value validation
if (onValueValidate) {
const result = onValueValidate(item.value, index, item, items);
if (result && result.isValid === false) {
const error = errorsList[index];
if (error) {
error.value = result.message || "Invalid value";
}
}
}
});
// Validate duplicate keys
if (validateDuplicateKeys) {
keyCount.forEach((indices, key) => {
if (indices.length > 1) {
indices.forEach((index) => {
const error = errorsList[index];
if (error) {
error.key = "Duplicate key";
}
});
}
});
}
return errorsList;
}, [items]);
}, [
items,
validateDuplicateKeys,
validateEmptyKeys,
onKeyValidate,
onValueValidate,
]);
const isValid = useMemo(() => {
return errors.every((error) => !error.key && !error.value);
}, [errors]);
const hasAnyError = useMemo(() => {
return errors.some((error) => error.key || error.value);
@@ -268,12 +371,21 @@ export default function KeyValueInput({
}, [hasAnyError, errors]);
// Notify parent of validation changes
const onValidationChangeRef = useRef(onValidationChange);
const onValidationErrorRef = useRef(onValidationError);
useEffect(() => {
onValidationChangeRef.current = onValidationChange;
}, [onValidationChange]);
useEffect(() => {
onValidationErrorRef.current = onValidationError;
}, [onValidationError]);
useEffect(() => {
onValidationChangeRef.current?.(isValid, errors);
}, [isValid, errors]);
// Notify parent of error state for form library integration
useEffect(() => {
onValidationErrorRef.current?.(errorMessage);
@@ -282,17 +394,25 @@ export default function KeyValueInput({
const canRemoveItems = mode === "line" || items.length > 1;
const handleAdd = useCallback(() => {
if (onAdd) {
onAdd();
return;
}
onChange([...(items || []), { key: "", value: "" }]);
}, [onChange, items]);
}, [onAdd, onChange, items]);
const handleRemove = useCallback(
(index: number) => {
if (!canRemoveItems && items.length === 1) return;
if (onRemove) {
onRemove(index);
return;
}
const next = (items || []).filter((_, i) => i !== index);
onChange(next);
},
[canRemoveItems, items, onChange]
[canRemoveItems, items, onRemove, onChange]
);
const handleItemChange = useCallback(
@@ -311,6 +431,8 @@ export default function KeyValueInput({
}
}, [mode]); // Only run on mode change
const autoId = useId();
const fieldId = fieldContext?.baseId || name || `key-value-input-${autoId}`;
const gridCols = GRID_COLS[layout];
return (
@@ -338,24 +460,23 @@ export default function KeyValueInput({
key={index}
item={item}
onChange={(next) => handleItemChange(index, next)}
disabled={disabled}
onRemove={() => handleRemove(index)}
keyPlaceholder={keyPlaceholder}
valuePlaceholder={valuePlaceholder}
keyPlaceholder={keyTitle}
valuePlaceholder={valueTitle}
error={errors[index]}
canRemove={canRemoveItems}
index={index}
fieldId={fieldId}
/>
))}
</div>
) : (
<EmptyMessageCard
title="No items added yet."
padding="sm"
sizePreset="secondary"
/>
<EmptyMessageCard title="No items added yet." />
)}
<Button
disabled={disabled}
prominence="secondary"
onClick={handleAdd}
icon={SvgPlusCircle}

View File

@@ -1,7 +1,7 @@
"use client";
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
import Popover from "@/refresh-components/Popover";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
import { structureValue } from "@/lib/llmConfig/utils";
import {
@@ -11,11 +11,25 @@ import {
import { LLMProviderDescriptor } from "@/interfaces/llm";
import { Slider } from "@/components/ui/slider";
import { useUser } from "@/providers/UserProvider";
import LineItem from "@/refresh-components/buttons/LineItem";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { SvgRefreshCw } from "@opal/icons";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import {
Accordion,
AccordionContent,
AccordionItem,
AccordionTrigger,
} from "@/components/ui/accordion";
import {
SvgCheck,
SvgChevronDown,
SvgChevronRight,
SvgRefreshCw,
} from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import { OpenButton } from "@opal/components";
import { LLMOption, LLMOptionGroup } from "./interfaces";
import ModelListContent from "./ModelListContent";
export interface LLMPopoverProps {
llmManager: LlmManager;
@@ -136,6 +150,7 @@ export default function LLMPopover({
const isLoadingProviders = llmManager.isLoadingProviders;
const [open, setOpen] = useState(false);
const [searchQuery, setSearchQuery] = useState("");
const { user } = useUser();
const [localTemperature, setLocalTemperature] = useState(
@@ -146,7 +161,9 @@ export default function LLMPopover({
setLocalTemperature(llmManager.temperature ?? 0.5);
}, [llmManager.temperature]);
const searchInputRef = useRef<HTMLInputElement>(null);
const scrollContainerRef = useRef<HTMLDivElement>(null);
const selectedItemRef = useRef<HTMLDivElement>(null);
const handleGlobalTemperatureChange = useCallback((value: number[]) => {
const value_0 = value[0];
@@ -165,28 +182,39 @@ export default function LLMPopover({
[llmManager]
);
const isSelected = useCallback(
(option: LLMOption) =>
option.modelName === llmManager.currentLlm.modelName &&
option.provider === llmManager.currentLlm.provider,
[llmManager.currentLlm.modelName, llmManager.currentLlm.provider]
const llmOptions = useMemo(
() => buildLlmOptions(llmProviders, currentModelName),
[llmProviders, currentModelName]
);
const handleSelectModel = useCallback(
(option: LLMOption) => {
llmManager.updateCurrentLlm({
modelName: option.modelName,
provider: option.provider,
name: option.name,
} as LlmDescriptor);
onSelect?.(
structureValue(option.name, option.provider, option.modelName)
// Filter options by vision capability (when images are uploaded) and search query
const filteredOptions = useMemo(() => {
let result = llmOptions;
if (requiresImageInput) {
result = result.filter((opt) => opt.supportsImageInput);
}
if (searchQuery.trim()) {
const query = searchQuery.toLowerCase();
result = result.filter(
(opt) =>
opt.displayName.toLowerCase().includes(query) ||
opt.modelName.toLowerCase().includes(query) ||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
);
setOpen(false);
},
[llmManager, onSelect]
}
return result;
}, [llmOptions, searchQuery, requiresImageInput]);
// Group options by provider using backend-provided display names and ordering
// For aggregator providers (bedrock, openrouter, vertex_ai), flatten to "Provider/Vendor" format
const groupedOptions = useMemo(
() => groupLlmOptions(filteredOptions),
[filteredOptions]
);
// Get display name for the model to show in the button
// Use currentModelName prop if provided (e.g., for regenerate showing the model used),
// otherwise fall back to the globally selected model
const currentLlmDisplayName = useMemo(() => {
// Only use currentModelName if it's a non-empty string
const currentModel =
@@ -206,30 +234,122 @@ export default function LLMPopover({
return currentModel;
}, [llmProviders, currentModelName, llmManager.currentLlm.modelName]);
const temperatureFooter = user?.preferences?.temperature_override_enabled ? (
<>
<div className="border-t border-border-02 mx-2" />
<div className="flex flex-col w-full py-2 gap-2">
<Slider
value={[localTemperature]}
max={llmManager.maxTemperature}
min={0}
step={0.01}
onValueChange={handleGlobalTemperatureChange}
onValueCommit={handleGlobalTemperatureCommit}
className="w-full"
/>
<div className="flex flex-row items-center justify-between">
<Text secondaryBody text03>
Temperature (creativity)
</Text>
<Text secondaryBody text03>
{localTemperature.toFixed(1)}
</Text>
</div>
// Determine which group the current model belongs to (for auto-expand)
const currentGroupKey = useMemo(() => {
const currentModel = llmManager.currentLlm.modelName;
const currentProvider = llmManager.currentLlm.provider;
// Match by both modelName AND provider to handle same model name across providers
const option = llmOptions.find(
(o) => o.modelName === currentModel && o.provider === currentProvider
);
if (!option) return "openai";
const provider = option.provider.toLowerCase();
const isAggregator = AGGREGATOR_PROVIDERS.has(provider);
if (isAggregator && option.vendor) {
return `${provider}/${option.vendor.toLowerCase()}`;
}
return provider;
}, [
llmOptions,
llmManager.currentLlm.modelName,
llmManager.currentLlm.provider,
]);
// Track expanded groups - initialize with current model's group
const [expandedGroups, setExpandedGroups] = useState<string[]>([
currentGroupKey,
]);
// Reset state when popover closes/opens
useEffect(() => {
if (!open) {
setSearchQuery("");
} else {
// Reset expanded groups to only show the selected model's group
setExpandedGroups([currentGroupKey]);
}
}, [open, currentGroupKey]);
// Auto-scroll to selected model when popover opens
useEffect(() => {
if (open) {
// Small delay to let accordion content render
const timer = setTimeout(() => {
selectedItemRef.current?.scrollIntoView({
behavior: "instant",
block: "center",
});
}, 50);
return () => clearTimeout(timer);
}
}, [open]);
const isSearching = searchQuery.trim().length > 0;
// Compute final expanded groups
const effectiveExpandedGroups = useMemo(() => {
if (isSearching) {
// Force expand all when searching
return groupedOptions.map((g) => g.key);
}
return expandedGroups;
}, [isSearching, groupedOptions, expandedGroups]);
// Handler for accordion changes
const handleAccordionChange = (value: string[]) => {
// Only update state when not searching (force-expanding)
if (!isSearching) {
setExpandedGroups(value);
}
};
const handleSelectModel = (option: LLMOption) => {
llmManager.updateCurrentLlm({
modelName: option.modelName,
provider: option.provider,
name: option.name,
} as LlmDescriptor);
onSelect?.(structureValue(option.name, option.provider, option.modelName));
setOpen(false);
};
const renderModelItem = (option: LLMOption) => {
const isSelected =
option.modelName === llmManager.currentLlm.modelName &&
option.provider === llmManager.currentLlm.provider;
const capabilities: string[] = [];
if (option.supportsReasoning) {
capabilities.push("Reasoning");
}
if (option.supportsImageInput) {
capabilities.push("Vision");
}
const description =
capabilities.length > 0 ? capabilities.join(", ") : undefined;
return (
<div
key={`${option.name}-${option.modelName}`}
ref={isSelected ? selectedItemRef : undefined}
>
<LineItem
selected={isSelected}
description={description}
onClick={() => handleSelectModel(option)}
rightChildren={
isSelected ? (
<SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
) : null
}
>
{option.displayName}
</LineItem>
</div>
</>
) : undefined;
);
};
return (
<Popover open={open} onOpenChange={setOpen}>
@@ -253,16 +373,129 @@ export default function LLMPopover({
</div>
<Popover.Content side="top" align="end" width="xl">
<ModelListContent
llmProviders={llmProviders}
currentModelName={currentModelName}
requiresImageInput={requiresImageInput}
isLoading={isLoadingProviders}
onSelect={handleSelectModel}
isSelected={isSelected}
scrollContainerRef={scrollContainerRef}
footer={temperatureFooter}
/>
<Section gap={0.5}>
{/* Search Input */}
<InputTypeIn
ref={searchInputRef}
leftSearchIcon
variant="internal"
value={searchQuery}
onChange={(e) => setSearchQuery(e.target.value)}
placeholder="Search models..."
/>
{/* Model List with Vendor Groups */}
<PopoverMenu scrollContainerRef={scrollContainerRef}>
{isLoadingProviders
? [
<div key="loading" className="flex items-center gap-2 py-3">
<SimpleLoader />
<Text secondaryBody text03>
Loading models...
</Text>
</div>,
]
: groupedOptions.length === 0
? [
<div key="empty" className="py-3">
<Text secondaryBody text03>
No models found
</Text>
</div>,
]
: groupedOptions.length === 1
? // Single provider - show models directly without accordion
[
<div
key="single-provider"
className="flex flex-col gap-1"
>
{groupedOptions[0]!.options.map(renderModelItem)}
</div>,
]
: // Multiple providers - show accordion with groups
[
<Accordion
key="accordion"
type="multiple"
value={effectiveExpandedGroups}
onValueChange={handleAccordionChange}
className="w-full flex flex-col"
>
{groupedOptions.map((group) => {
const isExpanded = effectiveExpandedGroups.includes(
group.key
);
return (
<AccordionItem
key={group.key}
value={group.key}
className="border-none pt-1"
>
{/* Group Header */}
<AccordionTrigger className="flex items-center rounded-08 hover:no-underline hover:bg-background-tint-02 group [&>svg]:hidden w-full py-1">
<div className="flex items-center gap-1 shrink-0">
<div className="flex items-center justify-center size-5 shrink-0">
<group.Icon size={16} />
</div>
<Text
secondaryBody
text03
nowrap
className="px-0.5"
>
{group.displayName}
</Text>
</div>
<div className="flex-1" />
<div className="flex items-center justify-center size-6 shrink-0">
{isExpanded ? (
<SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
) : (
<SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
)}
</div>
</AccordionTrigger>
{/* Model Items - full width highlight */}
<AccordionContent className="pb-0 pt-0">
<div className="flex flex-col gap-1">
{group.options.map(renderModelItem)}
</div>
</AccordionContent>
</AccordionItem>
);
})}
</Accordion>,
]}
</PopoverMenu>
{/* Global Temperature Slider (shown if enabled in user prefs) */}
{user?.preferences?.temperature_override_enabled && (
<>
<div className="border-t border-border-02 mx-2" />
<div className="flex flex-col w-full py-2 gap-2">
<Slider
value={[localTemperature]}
max={llmManager.maxTemperature}
min={0}
step={0.01}
onValueChange={handleGlobalTemperatureChange}
onValueCommit={handleGlobalTemperatureCommit}
className="w-full"
/>
<div className="flex flex-row items-center justify-between">
<Text secondaryBody text03>
Temperature (creativity)
</Text>
<Text secondaryBody text03>
{localTemperature.toFixed(1)}
</Text>
</div>
</div>
</>
)}
</Section>
</Popover.Content>
</Popover>
);

View File

@@ -1,200 +0,0 @@
"use client";
import { useState, useMemo, useRef, useEffect } from "react";
import { PopoverMenu } from "@/refresh-components/Popover";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import { Text } from "@opal/components";
import { SvgCheck, SvgChevronDown, SvgChevronRight } from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import { LLMOption } from "./interfaces";
import { buildLlmOptions, groupLlmOptions } from "./LLMPopover";
import LineItem from "@/refresh-components/buttons/LineItem";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "@/refresh-components/Collapsible";
/** Props for the shared searchable model list used inside LLM popovers. */
export interface ModelListContentProps {
  // Configured providers to build the option list from; undefined while loading.
  llmProviders: LLMProviderDescriptor[] | undefined;
  // Passed through to buildLlmOptions (used when constructing the option list).
  currentModelName?: string;
  // When true, options without image-input support are filtered out.
  requiresImageInput?: boolean;
  // Invoked with the clicked option.
  onSelect: (option: LLMOption) => void;
  // Caller decides which option(s) render as selected.
  isSelected: (option: LLMOption) => boolean;
  // Optional: caller may disable individual options.
  isDisabled?: (option: LLMOption) => boolean;
  // Optional external scroll container ref; an internal ref is used otherwise.
  scrollContainerRef?: React.RefObject<HTMLDivElement | null>;
  // Shows a "Loading models..." placeholder instead of the list.
  isLoading?: boolean;
  // Rendered after the list (e.g. a temperature slider).
  footer?: React.ReactNode;
}
// Searchable, vendor-grouped model list shared by LLM popovers. Renders a
// search input, then either a loading/empty placeholder, a flat list (single
// provider group), or collapsible per-vendor groups, plus an optional footer.
export default function ModelListContent({
  llmProviders,
  currentModelName,
  requiresImageInput,
  onSelect,
  isSelected,
  isDisabled,
  scrollContainerRef: externalScrollRef,
  isLoading,
  footer,
}: ModelListContentProps) {
  // Free-text filter typed by the user; empty string = no filtering.
  const [searchQuery, setSearchQuery] = useState("");
  // Fall back to a local scroll ref when the caller does not supply one.
  const internalScrollRef = useRef<HTMLDivElement>(null);
  const scrollContainerRef = externalScrollRef ?? internalScrollRef;
  const llmOptions = useMemo(
    () => buildLlmOptions(llmProviders, currentModelName),
    [llmProviders, currentModelName]
  );
  // Apply the image-input capability filter first, then the search query
  // (matched case-insensitively against display name, model name, and vendor).
  const filteredOptions = useMemo(() => {
    let result = llmOptions;
    if (requiresImageInput) {
      result = result.filter((opt) => opt.supportsImageInput);
    }
    if (searchQuery.trim()) {
      const query = searchQuery.toLowerCase();
      result = result.filter(
        (opt) =>
          opt.displayName.toLowerCase().includes(query) ||
          opt.modelName.toLowerCase().includes(query) ||
          (opt.vendor && opt.vendor.toLowerCase().includes(query))
      );
    }
    return result;
  }, [llmOptions, searchQuery, requiresImageInput]);
  const groupedOptions = useMemo(
    () => groupLlmOptions(filteredOptions),
    [filteredOptions]
  );
  // Find which group contains a currently-selected model (for auto-expand)
  const defaultGroupKey = useMemo(() => {
    for (const group of groupedOptions) {
      if (group.options.some((opt) => isSelected(opt))) {
        return group.key;
      }
    }
    // No selected model visible — fall back to the first group ("" when empty).
    return groupedOptions[0]?.key ?? "";
  }, [groupedOptions, isSelected]);
  const [expandedGroups, setExpandedGroups] = useState<Set<string>>(
    new Set([defaultGroupKey])
  );
  // Reset expanded groups when default changes (e.g. popover re-opens)
  useEffect(() => {
    setExpandedGroups(new Set([defaultGroupKey]));
  }, [defaultGroupKey]);
  const isSearching = searchQuery.trim().length > 0;
  // Toggle one vendor group. No-op while searching, since searching
  // force-expands every group (see isGroupOpen below).
  const toggleGroup = (key: string) => {
    if (isSearching) return;
    setExpandedGroups((prev) => {
      const next = new Set(prev);
      if (next.has(key)) next.delete(key);
      else next.add(key);
      return next;
    });
  };
  const isGroupOpen = (key: string) => isSearching || expandedGroups.has(key);
  // Render one selectable model row; capability tags become the description.
  const renderModelItem = (option: LLMOption) => {
    const selected = isSelected(option);
    const disabled = isDisabled?.(option) ?? false;
    const capabilities: string[] = [];
    if (option.supportsReasoning) capabilities.push("Reasoning");
    if (option.supportsImageInput) capabilities.push("Vision");
    const description =
      capabilities.length > 0 ? capabilities.join(", ") : undefined;
    return (
      <LineItem
        key={`${option.provider}:${option.modelName}`}
        selected={selected}
        disabled={disabled}
        description={description}
        onClick={() => onSelect(option)}
        rightChildren={
          selected ? (
            <SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
          ) : null
        }
      >
        {option.displayName}
      </LineItem>
    );
  };
  return (
    <Section gap={0.5}>
      <InputTypeIn
        leftSearchIcon
        variant="internal"
        value={searchQuery}
        onChange={(e) => setSearchQuery(e.target.value)}
        placeholder="Search models..."
      />
      <PopoverMenu scrollContainerRef={scrollContainerRef}>
        {/* Loading → empty → single group (flat) → multiple groups (collapsible) */}
        {isLoading
          ? [
              <Text key="loading" font="secondary-body" color="text-03">
                Loading models...
              </Text>,
            ]
          : groupedOptions.length === 0
            ? [
                <Text key="empty" font="secondary-body" color="text-03">
                  No models found
                </Text>,
              ]
            : groupedOptions.length === 1
              ? [
                  <Section key="single-provider" gap={0.25}>
                    {groupedOptions[0]!.options.map(renderModelItem)}
                  </Section>,
                ]
              : groupedOptions.map((group) => {
                  const open = isGroupOpen(group.key);
                  return (
                    <Collapsible
                      key={group.key}
                      open={open}
                      onOpenChange={() => toggleGroup(group.key)}
                    >
                      <CollapsibleTrigger asChild>
                        <LineItem
                          muted
                          icon={group.Icon}
                          rightChildren={
                            open ? (
                              <SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
                            ) : (
                              <SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
                            )
                          }
                        >
                          {group.displayName}
                        </LineItem>
                      </CollapsibleTrigger>
                      <CollapsibleContent>
                        <Section gap={0.25}>
                          {group.options.map(renderModelItem)}
                        </Section>
                      </CollapsibleContent>
                    </Collapsible>
                  );
                })}
      </PopoverMenu>
      {footer}
    </Section>
  );
}

View File

@@ -1,230 +0,0 @@
"use client";
import { useState, useMemo, useRef } from "react";
import Popover from "@/refresh-components/Popover";
import { LlmManager } from "@/lib/hooks";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import { Button, SelectButton, OpenButton } from "@opal/components";
import { SvgPlusCircle, SvgX } from "@opal/icons";
import { LLMOption } from "@/refresh-components/popovers/interfaces";
import ModelListContent from "@/refresh-components/popovers/ModelListContent";
import Separator from "@/refresh-components/Separator";
/** Maximum number of models a user may have selected at once. */
export const MAX_MODELS = 3;
/** One selected model; (provider, modelName) acts as its identity key. */
export interface SelectedModel {
  // Provider configuration name; passed through unchanged from the LLMOption.
  name: string;
  provider: string;
  modelName: string;
  // Human-readable label shown on the pill.
  displayName: string;
}
export interface ModelSelectorProps {
  llmManager: LlmManager;
  // Currently selected models, in display order (at most MAX_MODELS).
  selectedModels: SelectedModel[];
  // Append a model (only called while below MAX_MODELS).
  onAdd: (model: SelectedModel) => void;
  // Remove the model at the given index.
  onRemove: (index: number) => void;
  // Replace the model at the given index with a new one.
  onReplace: (index: number, model: SelectedModel) => void;
}
/** Build the "<provider>:<modelName>" identity key for a selected model. */
function modelKey(provider: string, modelName: string): string {
  return [provider, modelName].join(":");
}
// Compact picker for up to MAX_MODELS chat models. Selected models render as
// pills; a "+" button opens the popover in add mode, while clicking a pill
// re-opens it in replace mode anchored to that pill.
export default function ModelSelector({
  llmManager,
  selectedModels,
  onAdd,
  onRemove,
  onReplace,
}: ModelSelectorProps) {
  const [open, setOpen] = useState(false);
  // null = add mode (via + button), number = replace mode (via pill click)
  const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
  // Virtual anchor ref — points to the clicked pill so the popover positions above it
  const anchorRef = useRef<HTMLElement | null>(null);
  const isMultiModel = selectedModels.length > 1;
  const atMax = selectedModels.length >= MAX_MODELS;
  // Identity keys of every selected model (used in add mode).
  const selectedKeys = useMemo(
    () => new Set(selectedModels.map((m) => modelKey(m.provider, m.modelName))),
    [selectedModels]
  );
  // In replace mode: keys occupying the OTHER slots — those stay disabled so
  // the same model cannot fill two slots.
  const otherSelectedKeys = useMemo(() => {
    if (replacingIndex === null) return new Set<string>();
    return new Set(
      selectedModels
        .filter((_, i) => i !== replacingIndex)
        .map((m) => modelKey(m.provider, m.modelName))
    );
  }, [selectedModels, replacingIndex]);
  // Key of the model currently occupying the slot being replaced (if any).
  const replacingKey =
    replacingIndex !== null
      ? (() => {
          const m = selectedModels[replacingIndex];
          return m ? modelKey(m.provider, m.modelName) : null;
        })()
      : null;
  // Replace mode highlights only the slot being replaced; add mode highlights
  // every selected model.
  const isSelected = (option: LLMOption) => {
    const key = modelKey(option.provider, option.modelName);
    if (replacingIndex !== null) return key === replacingKey;
    return selectedKeys.has(key);
  };
  // Disabled when choosing it would duplicate another slot (replace mode), or
  // when the selection is full and the option is not already selected.
  const isDisabled = (option: LLMOption) => {
    const key = modelKey(option.provider, option.modelName);
    if (replacingIndex !== null) return otherSelectedKeys.has(key);
    return !selectedKeys.has(key) && atMax;
  };
  const handleSelect = (option: LLMOption) => {
    const model: SelectedModel = {
      name: option.name,
      provider: option.provider,
      modelName: option.modelName,
      displayName: option.displayName,
    };
    if (replacingIndex !== null) {
      // Replace mode: swap the slot, close, and exit replace mode.
      onReplace(replacingIndex, model);
      setOpen(false);
      setReplacingIndex(null);
      return;
    }
    // Add mode: clicking an already-selected model toggles it off; otherwise
    // add it (only while below the cap). Popover stays open for multi-select.
    const key = modelKey(option.provider, option.modelName);
    const existingIndex = selectedModels.findIndex(
      (m) => modelKey(m.provider, m.modelName) === key
    );
    if (existingIndex >= 0) {
      onRemove(existingIndex);
    } else if (!atMax) {
      onAdd(model);
    }
  };
  const handleOpenChange = (nextOpen: boolean) => {
    setOpen(nextOpen);
    // Closing the popover always exits replace mode.
    if (!nextOpen) setReplacingIndex(null);
  };
  // Open the popover in replace mode, anchored to the clicked pill element.
  const handlePillClick = (index: number, element: HTMLElement) => {
    anchorRef.current = element;
    setReplacingIndex(index);
    setOpen(true);
  };
  return (
    <Popover open={open} onOpenChange={handleOpenChange}>
      <div className="flex items-center justify-end gap-1 p-1">
        {!atMax && (
          <Button
            prominence="tertiary"
            icon={SvgPlusCircle}
            size="sm"
            tooltip="Add Model"
            onClick={(e: React.MouseEvent) => {
              // Anchor to the + button itself for add mode.
              anchorRef.current = e.currentTarget as HTMLElement;
              setReplacingIndex(null);
              setOpen(true);
            }}
          />
        )}
        <Popover.Anchor
          virtualRef={anchorRef as React.RefObject<HTMLElement>}
        />
        {selectedModels.length > 0 && (
          <>
            {!atMax && (
              <Separator
                orientation="vertical"
                paddingXRem={0.5}
                paddingYRem={0.5}
              />
            )}
            <div className="flex items-center">
              {selectedModels.map((model, index) => {
                const ProviderIcon = getProviderIcon(
                  model.provider,
                  model.modelName
                );
                if (!isMultiModel) {
                  // Single-model display: plain pill, no remove affordance.
                  return (
                    <OpenButton
                      key={modelKey(model.provider, model.modelName)}
                      icon={ProviderIcon}
                      onClick={(e: React.MouseEvent) =>
                        handlePillClick(index, e.currentTarget as HTMLElement)
                      }
                    >
                      {model.displayName}
                    </OpenButton>
                  );
                }
                return (
                  <div
                    key={modelKey(model.provider, model.modelName)}
                    className="flex items-center"
                  >
                    {index > 0 && (
                      <Separator
                        orientation="vertical"
                        paddingXRem={0.5}
                        className="h-5"
                      />
                    )}
                    <SelectButton
                      icon={ProviderIcon}
                      rightIcon={SvgX}
                      state="empty"
                      variant="select-tinted"
                      interaction="hover"
                      size="lg"
                      onClick={(e: React.MouseEvent) => {
                        // Heuristic hit-test: the last
                        // ".interactive-foreground-icon" inside the button is
                        // assumed to be the trailing X icon — clicks on it
                        // remove the model; anywhere else enters replace mode.
                        // NOTE(review): fragile if SelectButton's icon markup
                        // changes — confirm against its DOM structure.
                        const target = e.target as HTMLElement;
                        const btn = e.currentTarget as HTMLElement;
                        const icons = btn.querySelectorAll(
                          ".interactive-foreground-icon"
                        );
                        const lastIcon = icons[icons.length - 1];
                        if (lastIcon && lastIcon.contains(target)) {
                          onRemove(index);
                        } else {
                          handlePillClick(index, btn);
                        }
                      }}
                    >
                      {model.displayName}
                    </SelectButton>
                  </div>
                );
              })}
            </div>
          </>
        )}
      </div>
      <Popover.Content
        side="top"
        align="start"
        width="lg"
        avoidCollisions={false}
      >
        <ModelListContent
          llmProviders={llmManager.llmProviders}
          isLoading={llmManager.isLoadingProviders}
          onSelect={handleSelect}
          isSelected={isSelected}
          isDisabled={isDisabled}
        />
      </Popover.Content>
    </Popover>
  );
}

View File

@@ -232,6 +232,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onboardingDismissed,
onboardingState,
onboardingActions,
llmDescriptors,
isLoadingOnboarding,
finishOnboarding,
hideOnboarding,
@@ -811,6 +812,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
handleFinishOnboarding={finishOnboarding}
state={onboardingState}
actions={onboardingActions}
llmDescriptors={llmDescriptors}
/>
)}

View File

@@ -1,7 +1,8 @@
"use client";
import { useState } from "react";
import { useMemo, useState } from "react";
import { useRouter } from "next/navigation";
import useSWR from "swr";
import { Table, Button } from "@opal/components";
import { IllustrationContent } from "@opal/layouts";
import { SvgUsers } from "@opal/icons";
@@ -13,14 +14,17 @@ import Text from "@/refresh-components/texts/Text";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import Separator from "@/refresh-components/Separator";
import { toast } from "@/hooks/useToast";
import useGroupMemberCandidates from "./useGroupMemberCandidates";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useAdminUsers from "@/hooks/useAdminUsers";
import { SWR_KEYS } from "@/lib/swr-keys";
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
import {
createGroup,
updateAgentGroupSharing,
updateDocSetGroupSharing,
saveTokenLimits,
} from "./svc";
import { memberTableColumns, PAGE_SIZE } from "./shared";
import { apiKeyToMemberRow, memberTableColumns, PAGE_SIZE } from "./shared";
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
import TokenLimitSection from "./TokenLimitSection";
import type { TokenLimit } from "./TokenLimitSection";
@@ -38,7 +42,22 @@ function CreateGroupPage() {
{ tokenBudget: null, periodHours: null },
]);
const { rows: allRows, isLoading, error } = useGroupMemberCandidates();
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
const {
data: apiKeys,
isLoading: apiKeysLoading,
error: apiKeysError,
} = useSWR<ApiKeyDescriptor[]>(SWR_KEYS.adminApiKeys, errorHandlingFetcher);
const isLoading = usersLoading || apiKeysLoading;
const error = usersError ?? apiKeysError;
const allRows: MemberRow[] = useMemo(() => {
const activeUsers = users.filter((u) => u.is_active);
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
return [...activeUsers, ...serviceAccountRows];
}, [users, apiKeys]);
async function handleCreate() {
const trimmed = groupName.trim();
@@ -115,11 +134,11 @@ function CreateGroupPage() {
{/* Members table */}
{isLoading && <SimpleLoader />}
{error ? (
{error && (
<Text as="p" secondaryBody text03>
Failed to load users.
</Text>
) : null}
)}
{!isLoading && !error && (
<Section

View File

@@ -3,7 +3,6 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useRouter } from "next/navigation";
import useSWR, { useSWRConfig } from "swr";
import useGroupMemberCandidates from "./useGroupMemberCandidates";
import { Table, Button } from "@opal/components";
import { IllustrationContent } from "@opal/layouts";
import { SvgUsers, SvgTrash, SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
@@ -20,9 +19,20 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
import Separator from "@/refresh-components/Separator";
import { toast } from "@/hooks/useToast";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useAdminUsers from "@/hooks/useAdminUsers";
import type { UserGroup } from "@/lib/types";
import type { MemberRow, TokenRateLimitDisplay } from "./interfaces";
import { baseColumns, memberTableColumns, tc, PAGE_SIZE } from "./shared";
import type {
ApiKeyDescriptor,
MemberRow,
TokenRateLimitDisplay,
} from "./interfaces";
import {
apiKeyToMemberRow,
baseColumns,
memberTableColumns,
tc,
PAGE_SIZE,
} from "./shared";
import {
renameGroup,
updateGroup,
@@ -94,15 +104,18 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
const initialAgentIdsRef = useRef<number[]>([]);
const initialDocSetIdsRef = useRef<number[]>([]);
// Users + service accounts (curator-accessible — see hook docs).
const {
rows: allRows,
isLoading: candidatesLoading,
error: candidatesError,
} = useGroupMemberCandidates();
// Users and API keys
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
const isLoading = groupLoading || candidatesLoading || tokenLimitsLoading;
const error = groupError ?? candidatesError;
const {
data: apiKeys,
isLoading: apiKeysLoading,
error: apiKeysError,
} = useSWR<ApiKeyDescriptor[]>(SWR_KEYS.adminApiKeys, errorHandlingFetcher);
const isLoading =
groupLoading || usersLoading || apiKeysLoading || tokenLimitsLoading;
const error = groupError ?? usersError ?? apiKeysError;
// Pre-populate form when group data loads
useEffect(() => {
@@ -132,6 +145,12 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
}
}, [tokenRateLimits]);
const allRows = useMemo(() => {
const activeUsers = users.filter((u) => u.is_active);
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
return [...activeUsers, ...serviceAccountRows];
}, [users, apiKeys]);
const memberRows = useMemo(() => {
const selected = new Set(selectedUserIds);
return allRows.filter((r) => selected.has(r.id ?? r.email));

View File

@@ -1,147 +0,0 @@
"use client";
import { useMemo } from "react";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { SWR_KEYS } from "@/lib/swr-keys";
import { useUser } from "@/providers/UserProvider";
import { AccountType, UserStatus, type UserRole } from "@/lib/types";
import type {
UserGroupInfo,
UserRow,
} from "@/refresh-pages/admin/UsersPage/interfaces";
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
// Backend response shape for `/api/manage/users?include_api_keys=true`. The
// existing `AllUsersResponse` in `lib/types.ts` types `accepted` as `User[]`,
// which is missing fields the table needs (`personal_name`, `account_type`,
// `groups`, etc.), so we declare an accurate local type here.
interface FullUserSnapshot {
  id: string;
  email: string;
  role: UserRole;
  // Distinguishes real users from service accounts (API keys).
  account_type: AccountType;
  is_active: boolean;
  password_configured: boolean;
  personal_name: string | null;
  // Timestamps arrive as strings — presumably ISO 8601; confirm with backend.
  created_at: string;
  updated_at: string;
  groups: UserGroupInfo[];
  is_scim_synced: boolean;
}
// Response shape of `/api/manage/users?include_api_keys=true` (see the note
// above FullUserSnapshot for why this is declared locally).
interface ManageUsersResponse {
  accepted: FullUserSnapshot[];
  invited: { email: string }[];
  slack_users: FullUserSnapshot[];
  // Page counts for the corresponding lists (pagination metadata).
  accepted_pages: number;
  invited_pages: number;
  slack_users_pages: number;
}
/** Convert a backend user snapshot into the member-table row shape. */
function snapshotToMemberRow(snapshot: FullUserSnapshot): MemberRow {
  const {
    id,
    email,
    role,
    is_active,
    is_scim_synced,
    personal_name,
    created_at,
    updated_at,
    groups,
  } = snapshot;
  return {
    id,
    email,
    role,
    // Derive the display status from the active flag.
    status: is_active ? UserStatus.ACTIVE : UserStatus.INACTIVE,
    is_active,
    is_scim_synced,
    personal_name,
    created_at,
    updated_at,
    groups,
  };
}
/**
 * Convert a service-account user snapshot into a member-table row, enriched
 * with api-key metadata when available (admin-only; may be undefined).
 */
function serviceAccountToMemberRow(
  snapshot: FullUserSnapshot,
  apiKey: ApiKeyDescriptor | undefined
): MemberRow {
  // Prefer the key's own name; fall back to the snapshot, then a placeholder.
  const displayName =
    apiKey?.api_key_name ?? snapshot.personal_name ?? "Unnamed Key";
  return {
    id: snapshot.id,
    // Service accounts have no human email; show a fixed label instead.
    email: "Service Account",
    role: apiKey?.api_key_role ?? snapshot.role,
    // Service-account rows always render as active.
    status: UserStatus.ACTIVE,
    is_active: true,
    is_scim_synced: false,
    personal_name: displayName,
    created_at: null,
    updated_at: null,
    groups: [],
    api_key_display: apiKey?.api_key_display,
  };
}
interface UseGroupMemberCandidatesResult {
  /** Active users + service-account rows, in the order the table expects. */
  rows: MemberRow[];
  /** Subset of `rows` representing real (non-service-account) users. */
  userRows: MemberRow[];
  /** True while the users fetch (or, for admins, the api-key fetch) is in flight. */
  isLoading: boolean;
  /** Error from the users fetch only; api-key fetch failures are non-fatal. */
  error: unknown;
}
/**
 * Candidate list for the group create/edit member pickers.
 *
 * Fetches `/api/manage/users?include_api_keys=true`, which the backend gates
 * with `current_curator_or_admin_user`, so both admins and global curators can
 * use it (the previously-used admin-only `/accepted/all` and `/invited`
 * endpoints 403'd for global curators and broke the Edit Group page).
 *
 * Admins additionally fetch `/admin/api-key` so service-account rows can show
 * the masked api-key display string. That admin-only call is skipped for
 * curators and its failure is non-fatal (its error is never surfaced).
 */
export default function useGroupMemberCandidates(): UseGroupMemberCandidatesResult {
  const { isAdmin } = useUser();
  const usersResponse = useSWR<ManageUsersResponse>(
    SWR_KEYS.groupMemberCandidates,
    errorHandlingFetcher
  );
  // Conditional fetch: a null key disables the request for non-admins.
  const apiKeysResponse = useSWR<ApiKeyDescriptor[]>(
    isAdmin ? SWR_KEYS.adminApiKeys : null,
    errorHandlingFetcher
  );
  // Index api keys by owning user id for O(1) lookup per snapshot.
  const apiKeysByUserId = useMemo(() => {
    const byId = new Map<string, ApiKeyDescriptor>();
    for (const key of apiKeysResponse.data ?? []) {
      byId.set(key.user_id, key);
    }
    return byId;
  }, [apiKeysResponse.data]);
  // Partition active snapshots into real users and service accounts, keeping
  // users first so the table ordering stays stable.
  const { rows, userRows } = useMemo(() => {
    const regularRows: MemberRow[] = [];
    const serviceRows: MemberRow[] = [];
    for (const snapshot of usersResponse.data?.accepted ?? []) {
      if (!snapshot.is_active) continue;
      if (snapshot.account_type === AccountType.SERVICE_ACCOUNT) {
        serviceRows.push(
          serviceAccountToMemberRow(snapshot, apiKeysByUserId.get(snapshot.id))
        );
      } else {
        regularRows.push(snapshotToMemberRow(snapshot));
      }
    }
    return {
      rows: [...regularRows, ...serviceRows],
      userRows: regularRows,
    };
  }, [usersResponse.data, apiKeysByUserId]);
  return {
    rows,
    userRows,
    isLoading:
      usersResponse.isLoading || (isAdmin && apiKeysResponse.isLoading),
    error: usersResponse.error,
  };
}

View File

@@ -31,7 +31,6 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import Separator from "@/refresh-components/Separator";
import {
LLMProviderName,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
@@ -44,10 +43,9 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
import { Section } from "@/layouts/general-layouts";
const route = ADMIN_ROUTES.LLM_MODELS;
@@ -59,60 +57,93 @@ const route = ADMIN_ROUTES.LLM_MODELS;
// Client-side ordering for the "Add Provider" cards. The backend may return
// wellKnownLLMProviders in an arbitrary order, so we sort explicitly here.
const PROVIDER_DISPLAY_ORDER: string[] = [
LLMProviderName.OPENAI,
LLMProviderName.ANTHROPIC,
LLMProviderName.VERTEX_AI,
LLMProviderName.BEDROCK,
LLMProviderName.AZURE,
LLMProviderName.LITELLM,
LLMProviderName.LITELLM_PROXY,
LLMProviderName.OLLAMA_CHAT,
LLMProviderName.OPENROUTER,
LLMProviderName.LM_STUDIO,
LLMProviderName.BIFROST,
LLMProviderName.OPENAI_COMPATIBLE,
"openai",
"anthropic",
"vertex_ai",
"bedrock",
"azure",
"litellm_proxy",
"ollama_chat",
"openrouter",
"lm_studio",
"bifrost",
];
const PROVIDER_MODAL_MAP: Record<
string,
(
shouldMarkAsDefault: boolean,
open: boolean,
onOpenChange: (open: boolean) => void
) => React.ReactNode
> = {
openai: (d, onOpenChange) => (
<OpenAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
anthropic: (d, onOpenChange) => (
<AnthropicModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
ollama_chat: (d, onOpenChange) => (
<OllamaModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
azure: (d, onOpenChange) => (
<AzureModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
bedrock: (d, onOpenChange) => (
<BedrockModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
vertex_ai: (d, onOpenChange) => (
<VertexAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
openrouter: (d, onOpenChange) => (
<OpenRouterModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
lm_studio: (d, onOpenChange) => (
<LMStudioModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
litellm_proxy: (d, onOpenChange) => (
<LiteLLMProxyModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
bifrost: (d, onOpenChange) => (
<BifrostModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
),
openai_compatible: (d, onOpenChange) => (
<OpenAICompatibleModal
openai: (d, open, onOpenChange) => (
<OpenAIModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
anthropic: (d, open, onOpenChange) => (
<AnthropicModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
ollama_chat: (d, open, onOpenChange) => (
<OllamaModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
azure: (d, open, onOpenChange) => (
<AzureModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
bedrock: (d, open, onOpenChange) => (
<BedrockModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
vertex_ai: (d, open, onOpenChange) => (
<VertexAIModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
openrouter: (d, open, onOpenChange) => (
<OpenRouterModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
lm_studio: (d, open, onOpenChange) => (
<LMStudioForm
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
litellm_proxy: (d, open, onOpenChange) => (
<LiteLLMProxyModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
bifrost: (d, open, onOpenChange) => (
<BifrostModal
shouldMarkAsDefault={d}
open={open}
onOpenChange={onOpenChange}
/>
),
@@ -179,10 +210,7 @@ function ExistingProviderCard({
</ConfirmationModalLayout>
)}
<Hoverable.Root
group="ExistingProviderCard"
interaction={deleteModal.isOpen ? "hover" : "rest"}
>
<Hoverable.Root group="ExistingProviderCard">
<SelectCard
state="filled"
padding="sm"
@@ -224,8 +252,12 @@ function ExistingProviderCard({
</div>
}
/>
{isOpen &&
getModalForExistingProvider(provider, setIsOpen, defaultModelName)}
{getModalForExistingProvider(
provider,
isOpen,
setIsOpen,
defaultModelName
)}
</SelectCard>
</Hoverable.Root>
</>
@@ -241,6 +273,7 @@ interface NewProviderCardProps {
isFirstProvider: boolean;
formFn: (
shouldMarkAsDefault: boolean,
open: boolean,
onOpenChange: (open: boolean) => void
) => React.ReactNode;
}
@@ -278,7 +311,7 @@ function NewProviderCard({
</Button>
}
/>
{isOpen && formFn(isFirstProvider, setIsOpen)}
{formFn(isFirstProvider, isOpen, setIsOpen)}
</SelectCard>
);
}
@@ -322,12 +355,11 @@ function NewCustomProviderCard({
</Button>
}
/>
{isOpen && (
<CustomModal
shouldMarkAsDefault={isFirstProvider}
onOpenChange={setIsOpen}
/>
)}
<CustomModal
shouldMarkAsDefault={isFirstProvider}
open={isOpen}
onOpenChange={setIsOpen}
/>
</SelectCard>
);
}
@@ -336,7 +368,7 @@ function NewCustomProviderCard({
// LLMConfigurationPage — main page component
// ============================================================================
export default function LLMProviderConfigurationPage() {
export default function LLMConfigurationPage() {
const { mutate } = useSWRConfig();
const { llmProviders: existingLlmProviders, defaultText } =
useAdminLLMProviders();

View File

@@ -1,99 +1,177 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
import { Formik } from "formik";
import { LLMProviderFormProps } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
APIKeyField,
ModelSelectionField,
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
ModelsAccessField,
FieldSeparator,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import * as InputLayouts from "@/layouts/input-layouts";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
const ANTHROPIC_PROVIDER_NAME = "anthropic";
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
export default function AnthropicModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
ANTHROPIC_PROVIDER_NAME
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues = useInitialValues(
isOnboarding,
LLMProviderName.ANTHROPIC,
existingLlmProvider
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
});
const initialValues = isOnboarding
? {
...buildOnboardingInitialValues(),
name: ANTHROPIC_PROVIDER_NAME,
provider: ANTHROPIC_PROVIDER_NAME,
api_key: "",
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
}
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? undefined,
default_model_name:
(defaultModelName &&
modelConfigurations.some((m) => m.name === defaultModelName)
? defaultModelName
: undefined) ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
});
return (
<ModalWrapper
providerName={LLMProviderName.ANTHROPIC}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.ANTHROPIC,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: ANTHROPIC_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
is_auto_mode:
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: ANTHROPIC_PROVIDER_NAME,
values,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<APIKeyField providerName="Anthropic" />
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<APIKeyField providerName="Anthropic" />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. claude-sonnet-4-5" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
</Formik>
);
}

View File

@@ -1,35 +1,45 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import { Formik } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
APIKeyField,
DisplayNameField,
ModelAccessField,
ModelSelectionField,
ModalWrapper,
FieldSeparator,
FieldWrapper,
ModelsAccessField,
ModelsField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import {
isValidAzureTargetUri,
parseAzureTargetUri,
} from "@/lib/azureTargetUri";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const AZURE_PROVIDER_NAME = "azure";
interface AzureModalValues extends BaseLLMFormValues {
api_key: string;
@@ -39,33 +49,6 @@ interface AzureModalValues extends BaseLLMFormValues {
deployment_name?: string;
}
function AzureModelSelection() {
const formikProps = useFormikContext<AzureModalValues>();
return (
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onAddModel={(modelName) => {
const current = formikProps.values.model_configurations;
if (current.some((m) => m.name === modelName)) return;
const updated = [
...current,
{
name: modelName,
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
},
];
formikProps.setFieldValue("model_configurations", updated);
if (!formikProps.values.test_model_name) {
formikProps.setFieldValue("test_model_name", modelName);
}
}}
/>
);
}
function buildTargetUri(existingLlmProvider?: LLMProviderView): string {
if (!existingLlmProvider?.api_base || !existingLlmProvider?.api_version) {
return "";
@@ -100,104 +83,193 @@ export default function AzureModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(AZURE_PROVIDER_NAME);
const onClose = () => onOpenChange?.(false);
const [addedModels, setAddedModels] = useState<ModelConfiguration[]>([]);
const initialValues: AzureModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.AZURE,
existingLlmProvider
),
target_uri: buildTargetUri(existingLlmProvider),
} as AzureModalValues;
if (open === false) return null;
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
extra: {
target_uri: Yup.string()
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
(value) => (value ? isValidAzureTargetUri(value) : false)
const onClose = () => {
setAddedModels([]);
onOpenChange?.(false);
};
const baseModelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
// Merge base models with any user-added models (dedup by name)
const existingNames = new Set(baseModelConfigurations.map((m) => m.name));
const modelConfigurations = [
...baseModelConfigurations,
...addedModels.filter((m) => !existingNames.has(m.name)),
];
const initialValues: AzureModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: AZURE_PROVIDER_NAME,
provider: AZURE_PROVIDER_NAME,
api_key: "",
target_uri: "",
default_model_name: "",
} as AzureModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
},
});
api_key: existingLlmProvider?.api_key ?? "",
target_uri: buildTargetUri(existingLlmProvider),
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
target_uri: Yup.string()
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
(value) => (value ? isValidAzureTargetUri(value) : false)
),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
target_uri: Yup.string()
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
(value) => (value ? isValidAzureTargetUri(value) : false)
),
});
return (
<ModalWrapper
providerName={LLMProviderName.AZURE}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
const processedValues = processValues(values);
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.AZURE,
values: processedValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: AZURE_PROVIDER_NAME,
payload: {
...processedValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: AZURE_PROVIDER_NAME,
values: processedValues,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="target_uri"
title="Target URI"
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={AZURE_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<InputTypeInField
name="target_uri"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="target_uri"
title="Target URI"
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
>
<InputTypeInField
name="target_uri"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<APIKeyField providerName="Azure" />
<APIKeyField providerName="Azure" />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onAddModel={(modelName) => {
const newModel: ModelConfiguration = {
name: modelName,
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
};
setAddedModels((prev) => [...prev, newModel]);
const currentSelected =
formikProps.values.selected_model_names ?? [];
formikProps.setFieldValue("selected_model_names", [
...currentSelected,
modelName,
]);
if (!formikProps.values.default_model_name) {
formikProps.setFieldValue("default_model_name", modelName);
}
}}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
)}
<InputLayouts.FieldSeparator />
<AzureModelSelection />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
</Formik>
);
}

View File

@@ -1,8 +1,8 @@
"use client";
import { useEffect } from "react";
import { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
@@ -10,22 +10,30 @@ import * as InputLayouts from "@/layouts/input-layouts";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
ModelSelectionField,
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
FieldSeparator,
FieldWrapper,
ModelsAccessField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
import { Card } from "@opal/components";
@@ -33,9 +41,9 @@ import { Section } from "@/layouts/general-layouts";
import { SvgAlertCircle } from "@opal/icons";
import { Content } from "@opal/layouts";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import useOnMount from "@/hooks/useOnMount";
const BEDROCK_PROVIDER_NAME = "bedrock";
const AWS_REGION_OPTIONS = [
{ name: "us-east-1", value: "us-east-1" },
{ name: "us-east-2", value: "us-east-2" },
@@ -71,15 +79,26 @@ interface BedrockModalValues extends BaseLLMFormValues {
}
interface BedrockModalInternalsProps {
formikProps: FormikProps<BedrockModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function BedrockModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: BedrockModalInternalsProps) {
const formikProps = useFormikContext<BedrockModalValues>();
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
useEffect(() => {
@@ -96,6 +115,11 @@ function BedrockModalInternals({
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [authMethod]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const isAuthComplete =
authMethod === AUTH_METHOD_IAM ||
(authMethod === AUTH_METHOD_ACCESS_KEY &&
@@ -115,12 +139,12 @@ function BedrockModalInternals({
formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
aws_bearer_token_bedrock:
formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
provider_name: LLMProviderName.BEDROCK,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
setFetchedModels(models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -135,8 +159,16 @@ function BedrockModalInternals({
});
return (
<>
<InputLayouts.FieldPadder>
<LLMConfigurationModalWrapper
providerEndpoint={BEDROCK_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<Section gap={1}>
<InputLayouts.Vertical
name={FIELD_AWS_REGION_NAME}
@@ -190,7 +222,7 @@ function BedrockModalInternals({
</InputSelect>
</InputLayouts.Vertical>
</Section>
</InputLayouts.FieldPadder>
</FieldWrapper>
{authMethod === AUTH_METHOD_ACCESS_KEY && (
<Card background="light" border="none" padding="sm">
@@ -218,7 +250,7 @@ function BedrockModalInternals({
)}
{authMethod === AUTH_METHOD_IAM && (
<InputLayouts.FieldPadder>
<FieldWrapper>
<Card background="none" border="solid" padding="sm">
<Content
icon={SvgAlertCircle}
@@ -227,7 +259,7 @@ function BedrockModalInternals({
sizePreset="main-ui"
/>
</Card>
</InputLayouts.FieldPadder>
</FieldWrapper>
)}
{authMethod === AUTH_METHOD_LONG_TERM_API_KEY && (
@@ -248,24 +280,32 @@ function BedrockModalInternals({
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. us.anthropic.claude-sonnet-4-5-v1" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</>
</LLMConfigurationModalWrapper>
);
}
@@ -273,53 +313,88 @@ export default function BedrockModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
BEDROCK_PROVIDER_NAME
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues: BedrockModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.BEDROCK,
existingLlmProvider
),
custom_config: {
AWS_REGION_NAME:
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ?? "",
BEDROCK_AUTH_METHOD:
(existingLlmProvider?.custom_config?.BEDROCK_AUTH_METHOD as string) ??
"access_key",
AWS_ACCESS_KEY_ID:
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ?? "",
AWS_SECRET_ACCESS_KEY:
(existingLlmProvider?.custom_config?.AWS_SECRET_ACCESS_KEY as string) ??
"",
AWS_BEARER_TOKEN_BEDROCK:
(existingLlmProvider?.custom_config
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
},
} as BedrockModalValues;
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
extra: {
custom_config: Yup.object({
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
}),
},
});
const initialValues: BedrockModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: BEDROCK_PROVIDER_NAME,
provider: BEDROCK_PROVIDER_NAME,
default_model_name: "",
custom_config: {
AWS_REGION_NAME: "",
BEDROCK_AUTH_METHOD: "access_key",
AWS_ACCESS_KEY_ID: "",
AWS_SECRET_ACCESS_KEY: "",
AWS_BEARER_TOKEN_BEDROCK: "",
},
} as BedrockModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
custom_config: {
AWS_REGION_NAME:
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??
"",
BEDROCK_AUTH_METHOD:
(existingLlmProvider?.custom_config
?.BEDROCK_AUTH_METHOD as string) ?? "access_key",
AWS_ACCESS_KEY_ID:
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ??
"",
AWS_SECRET_ACCESS_KEY:
(existingLlmProvider?.custom_config
?.AWS_SECRET_ACCESS_KEY as string) ?? "",
AWS_BEARER_TOKEN_BEDROCK:
(existingLlmProvider?.custom_config
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
default_model_name: Yup.string().required("Model name is required"),
custom_config: Yup.object({
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
}),
})
: buildDefaultValidationSchema().shape({
custom_config: Yup.object({
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
}),
});
return (
<ModalWrapper
providerName={LLMProviderName.BEDROCK}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
);
@@ -332,37 +407,51 @@ export default function BedrockModal({
: undefined,
};
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.BEDROCK,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: BEDROCK_PROVIDER_NAME,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: BEDROCK_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<BedrockModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
{(formikProps) => (
<BedrockModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
);
}

View File

@@ -1,33 +1,45 @@
"use client";
import { useEffect } from "react";
import { useState, useEffect } from "react";
import { markdown } from "@opal/utils";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIBaseField,
APIKeyField,
ModelSelectionField,
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
const DEFAULT_API_BASE = "";
interface BifrostModalValues extends BaseLLMFormValues {
api_key: string;
@@ -35,15 +47,30 @@ interface BifrostModalValues extends BaseLLMFormValues {
}
interface BifrostModalInternalsProps {
formikProps: FormikProps<BifrostModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function BifrostModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: BifrostModalInternalsProps) {
const formikProps = useFormikContext<BifrostModalValues>();
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const isFetchDisabled = !formikProps.values.api_base;
@@ -51,12 +78,12 @@ function BifrostModalInternals({
const { models, error } = await fetchBifrostModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key || undefined,
provider_name: LLMProviderName.BIFROST,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
setFetchedModels(models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -73,39 +100,69 @@ function BifrostModalInternals({
}, []);
return (
<>
<APIBaseField
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
placeholder="https://your-bifrost-gateway.com/v1"
/>
<LLMConfigurationModalWrapper
providerEndpoint={LLMProviderName.BIFROST}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
>
<InputTypeInField
name="api_base"
placeholder="https://your-bifrost-gateway.com/v1"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<APIKeyField
optional
subDescription={markdown(
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
)}
/>
<FieldWrapper>
<InputLayouts.Vertical
name="api_key"
title="API Key"
suffix="optional"
subDescription={markdown(
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
)}
>
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
</InputLayouts.Vertical>
</FieldWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</>
</LLMConfigurationModalWrapper>
);
}
@@ -113,63 +170,109 @@ export default function BifrostModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
BIFROST_PROVIDER_NAME
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues: BifrostModalValues = useInitialValues(
isOnboarding,
LLMProviderName.BIFROST,
existingLlmProvider
) as BifrostModalValues;
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
apiBase: true,
});
const initialValues: BifrostModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: BIFROST_PROVIDER_NAME,
provider: BIFROST_PROVIDER_NAME,
api_key: "",
api_base: DEFAULT_API_BASE,
default_model_name: "",
} as BifrostModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_base: Yup.string().required("API Base URL is required"),
});
return (
<ModalWrapper
providerName={LLMProviderName.BIFROST}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.BIFROST,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: BIFROST_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: BIFROST_PROVIDER_NAME,
values,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<BifrostModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
{(formikProps) => (
<BifrostModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
);
}

View File

@@ -70,9 +70,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
}
) {
const nameInput = screen.getByPlaceholderText("Display Name");
const providerInput = screen.getByPlaceholderText(
"Provider Name as shown on LiteLLM"
);
const providerInput = screen.getByPlaceholderText("Provider Name");
await user.type(nameInput, options.name);
await user.type(providerInput, options.provider);
@@ -101,7 +99,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
}),
} as Response);
render(<CustomModal onOpenChange={() => {}} />);
render(<CustomModal open={true} onOpenChange={() => {}} />);
await fillBasicFields(user, {
name: "My Custom Provider",
@@ -168,7 +166,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ detail: "Invalid API key" }),
} as Response);
render(<CustomModal onOpenChange={() => {}} />);
render(<CustomModal open={true} onOpenChange={() => {}} />);
await fillBasicFields(user, {
name: "Bad Provider",
@@ -246,6 +244,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
render(
<CustomModal
existingLlmProvider={existingProvider}
open={true}
onOpenChange={() => {}}
/>
);
@@ -340,6 +339,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
render(
<CustomModal
existingLlmProvider={existingProvider}
open={true}
onOpenChange={() => {}}
/>
);
@@ -406,7 +406,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({}),
} as Response);
render(<CustomModal shouldMarkAsDefault={true} onOpenChange={() => {}} />);
render(
<CustomModal
shouldMarkAsDefault={true}
open={true}
onOpenChange={() => {}}
/>
);
await fillBasicFields(user, {
name: "New Default Provider",
@@ -451,7 +457,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ detail: "Database error" }),
} as Response);
render(<CustomModal onOpenChange={() => {}} />);
render(<CustomModal open={true} onOpenChange={() => {}} />);
await fillBasicFields(user, {
name: "Test Provider",
@@ -486,15 +492,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
json: async () => ({ id: 1, name: "Provider with Custom Config" }),
} as Response);
render(<CustomModal onOpenChange={() => {}} />);
render(<CustomModal open={true} onOpenChange={() => {}} />);
// Fill basic fields
const nameInput = screen.getByPlaceholderText("Display Name");
await user.type(nameInput, "Cloudflare Provider");
const providerInput = screen.getByPlaceholderText(
"Provider Name as shown on LiteLLM"
);
const providerInput = screen.getByPlaceholderText("Provider Name");
await user.type(providerInput, "cloudflare");
// Click "Add Line" button for custom config (aria-label from KeyValueInput)
@@ -504,8 +508,8 @@ describe("Custom LLM Provider Configuration Workflow", () => {
await user.click(addLineButton);
// Fill in custom config key-value pair
const keyInputs = screen.getAllByRole("textbox", { name: /Key \d+/ });
const valueInputs = screen.getAllByRole("textbox", { name: /Value \d+/ });
const keyInputs = screen.getAllByPlaceholderText("Key");
const valueInputs = screen.getAllByPlaceholderText("Value");
await user.type(keyInputs[0]!, "CLOUDFLARE_ACCOUNT_ID");
await user.type(valueInputs[0]!, "my-account-id-123");

View File

@@ -1,22 +1,24 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import {
LLMProviderFormProps,
LLMProviderName,
ModelConfiguration,
} from "@/interfaces/llm";
import { Formik, FormikProps } from "formik";
import { LLMProviderFormProps, ModelConfiguration } from "@/interfaces/llm";
import * as Yup from "yup";
import { useInitialValues } from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
APIBaseField,
buildDefaultInitialValues,
buildOnboardingInitialValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
DisplayNameField,
ModelAccessField,
ModalWrapper,
FieldSeparator,
ModelsAccessField,
LLMConfigurationModalWrapper,
FieldWrapper,
} from "@/sections/modals/llmConfig/shared";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
@@ -28,9 +30,7 @@ import InputSelect from "@/refresh-components/inputs/InputSelect";
import Text from "@/refresh-components/texts/Text";
import { Button, Card, EmptyMessageCard } from "@opal/components";
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
import { markdown } from "@opal/utils";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { Content } from "@opal/layouts";
import { Section } from "@/layouts/general-layouts";
@@ -107,10 +107,13 @@ function ModelConfigurationItem({
);
}
function ModelConfigurationList() {
const formikProps = useFormikContext<{
interface ModelConfigurationListProps {
formikProps: FormikProps<{
model_configurations: CustomModelConfiguration[];
}>();
}>;
}
function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
const models = formikProps.values.model_configurations;
function handleChange(index: number, next: CustomModelConfiguration) {
@@ -176,68 +179,55 @@ function ModelConfigurationList() {
);
}
function CustomConfigKeyValue() {
const formikProps = useFormikContext<{ custom_config_list: KeyValue[] }>();
return (
<KeyValueInput
items={formikProps.values.custom_config_list}
onChange={(items) =>
formikProps.setFieldValue("custom_config_list", items)
}
addButtonLabel="Add Line"
/>
);
}
// ─── Custom Config Processing ─────────────────────────────────────────────────
function keyValueListToDict(items: KeyValue[]): Record<string, string> {
const result: Record<string, string> = {};
for (const { key, value } of items) {
if (key.trim() !== "") {
result[key] = value;
}
}
return result;
function customConfigProcessing(items: KeyValue[]) {
const customConfig: { [key: string]: string } = {};
items.forEach(({ key, value }) => {
customConfig[key] = value;
});
return customConfig;
}
export default function CustomModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.CUSTOM,
existingLlmProvider
...buildDefaultInitialValues(
existingLlmProvider,
undefined,
defaultModelName
),
...(isOnboarding ? buildOnboardingInitialValues() : {}),
provider: existingLlmProvider?.provider ?? "",
api_version: existingLlmProvider?.api_version ?? "",
model_configurations: existingLlmProvider?.model_configurations.map(
(mc) => ({
name: mc.name,
display_name: mc.display_name ?? "",
is_visible: mc.is_visible,
max_input_tokens: mc.max_input_tokens ?? null,
supports_image_input: mc.supports_image_input,
supports_reasoning: mc.supports_reasoning,
})
) ?? [
{
name: "",
display_name: "",
is_visible: true,
max_input_tokens: null,
supports_image_input: false,
supports_reasoning: false,
},
],
custom_config_list: existingLlmProvider?.custom_config
@@ -269,13 +259,11 @@ export default function CustomModal({
});
return (
<ModalWrapper
providerName={LLMProviderName.CUSTOM}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
setSubmitting(true);
const modelConfigurations = values.model_configurations
@@ -295,127 +283,127 @@ export default function CustomModal({
return;
}
// Always send custom_config as a dict (even empty) so the backend
// preserves it as non-null — this is the signal that the provider was
// created via CustomModal.
const customConfig = keyValueListToDict(values.custom_config_list);
if (isOnboarding && onboardingState && onboardingActions) {
await submitOnboardingProvider({
providerName: values.provider,
payload: {
...values,
model_configurations: modelConfigurations,
custom_config: customConfigProcessing(values.custom_config_list),
},
onboardingState,
onboardingActions,
isCustomProvider: true,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
const selectedModelNames = modelConfigurations.map(
(config) => config.name
);
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: (values as Record<string, unknown>).provider as string,
values: {
...values,
model_configurations: modelConfigurations,
custom_config: customConfig,
},
initialValues: {
...initialValues,
custom_config: keyValueListToDict(initialValues.custom_config_list),
},
existingLlmProvider,
shouldMarkAsDefault,
isCustomProvider: true,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
await submitLLMProvider({
providerName: values.provider,
values: {
...values,
selected_model_names: selectedModelNames,
custom_config: customConfigProcessing(values.custom_config_list),
},
initialValues: {
...initialValues,
custom_config: customConfigProcessing(
initialValues.custom_config_list
),
},
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
{!isOnboarding && (
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="provider"
title="Provider Name"
subDescription={markdown(
"Should be one of the providers listed at [LiteLLM](https://docs.litellm.ai/docs/providers)."
)}
>
<InputTypeInField
name="provider"
placeholder="Provider Name as shown on LiteLLM"
variant={existingLlmProvider ? "disabled" : undefined}
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
)}
<APIBaseField optional />
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="api_version"
title="API Version"
suffix="optional"
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint="custom"
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<InputTypeInField name="api_version" />
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
{!isOnboarding && (
<Section gap={0}>
<DisplayNameField disabled={!!existingLlmProvider} />
<APIKeyField
optional
subDescription="Paste your API key if your model provider requires authentication."
/>
<FieldWrapper>
<InputLayouts.Vertical
name="provider"
title="Provider Name"
subDescription="Should be one of the providers listed at https://docs.litellm.ai/docs/providers."
>
<InputTypeInField
name="provider"
placeholder="Provider Name"
variant={existingLlmProvider ? "disabled" : undefined}
/>
</InputLayouts.Vertical>
</FieldWrapper>
</Section>
)}
<InputLayouts.FieldPadder>
<Section gap={0.75}>
<Content
title="Additional Configs"
description={markdown(
"Add extra properties as needed by the model provider. These are passed to LiteLLM's `completion()` call as [environment variables](https://docs.litellm.ai/docs/set_keys#environment-variables). See [documentation](https://docs.onyx.app/admins/ai_models/custom_inference_provider) for more instructions."
)}
widthVariant="full"
variant="section"
sizePreset="main-content"
/>
<FieldSeparator />
<CustomConfigKeyValue />
</Section>
</InputLayouts.FieldPadder>
<FieldWrapper>
<Section gap={0.75}>
<Content
title="Provider Configs"
description="Add properties as needed by the model provider. This is passed to LiteLLM completion() call as arguments in the environment variable. See LiteLLM documentation for more instructions."
widthVariant="full"
variant="section"
sizePreset="main-content"
/>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
<KeyValueInput
items={formikProps.values.custom_config_list}
onChange={(items) =>
formikProps.setFieldValue("custom_config_list", items)
}
addButtonLabel="Add Line"
/>
</Section>
</FieldWrapper>
<FieldSeparator />
<Section gap={0.5}>
<FieldWrapper>
<Content
title="Models"
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
variant="section"
sizePreset="main-content"
widthVariant="full"
/>
</FieldWrapper>
<Card padding="sm">
<ModelConfigurationList formikProps={formikProps as any} />
</Card>
</Section>
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
)}
<InputLayouts.FieldSeparator />
<Section gap={0.5}>
<InputLayouts.FieldPadder>
<Content
title="Models"
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
variant="section"
sizePreset="main-content"
widthVariant="full"
/>
</InputLayouts.FieldPadder>
<Card padding="sm">
<ModelConfigurationList />
</Card>
</Section>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
</Formik>
);
}

View File

@@ -0,0 +1,315 @@
"use client";
import { useCallback, useEffect, useMemo, useState } from "react";
import { useSWRConfig } from "swr";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
import { toast } from "@/hooks/useToast";
const DEFAULT_API_BASE = "http://localhost:1234";
/**
 * Formik value shape for the LM Studio provider form.
 * Extends the shared base form values with the two LM-Studio-specific fields.
 */
interface LMStudioFormValues extends BaseLLMFormValues {
  // Base URL of the LM Studio server (defaults to http://localhost:1234).
  api_base: string;
  custom_config: {
    // Optional key sent when the LM Studio server requires authentication;
    // empty values are stripped before submit (see LMStudioForm.onSubmit).
    LM_STUDIO_API_KEY?: string;
  };
}
/**
 * Props for the inner form body rendered inside the <Formik> render-prop.
 * Model-fetch state is lifted to the parent so the submit handler can use it.
 */
interface LMStudioFormInternalsProps {
  formikProps: FormikProps<LMStudioFormValues>;
  // Present when editing an already-configured provider; undefined on create.
  existingLlmProvider: LLMProviderView | undefined;
  // Models discovered from the live server for the current api_base/api_key.
  fetchedModels: ModelConfiguration[];
  setFetchedModels: (models: ModelConfiguration[]) => void;
  // True while the parent is test-connecting the provider (disables actions).
  isTesting: boolean;
  onClose: () => void;
  // Onboarding variant hides display-name/model-access fields.
  isOnboarding: boolean;
}
/**
 * Inner body of the LM Studio provider form.
 *
 * Watches the `api_base` / `LM_STUDIO_API_KEY` form values and, after a
 * 500 ms debounce, queries the LM Studio server for its available models.
 * In-flight requests are cancelled via AbortController whenever the inputs
 * change or the component unmounts.
 */
function LMStudioFormInternals({
  formikProps,
  existingLlmProvider,
  fetchedModels,
  setFetchedModels,
  isTesting,
  onClose,
  isOnboarding,
}: LMStudioFormInternalsProps) {
  // Key stored on the saved provider (if any); lets the backend know whether
  // the key the user typed differs from the persisted one.
  const initialApiKey =
    (existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ?? "";

  const doFetchModels = useCallback(
    (apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
      fetchModels(
        LLMProviderName.LM_STUDIO,
        {
          api_base: apiBase,
          custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
          api_key_changed: apiKey !== initialApiKey,
          name: existingLlmProvider?.name,
        },
        signal
      )
        .then((data) => {
          if (signal.aborted) return;
          if (data.error) {
            toast.error(data.error);
            setFetchedModels([]);
            return;
          }
          setFetchedModels(data.models);
        })
        // FIX: the promise previously had no rejection handler, so a network
        // failure (or the rejection raised when the controller aborts the
        // request) surfaced as an unhandled promise rejection instead of a
        // user-visible error. Aborts are intentional and stay silent.
        .catch((err: unknown) => {
          if (signal.aborted) return;
          toast.error(
            err instanceof Error ? err.message : "Failed to fetch models"
          );
          setFetchedModels([]);
        });
    },
    [existingLlmProvider?.name, initialApiKey, setFetchedModels]
  );

  // Debounce so we don't hammer the server while the user is still typing.
  const debouncedFetchModels = useMemo(
    () => debounce(doFetchModels, 500),
    [doFetchModels]
  );

  const apiBase = formikProps.values.api_base;
  const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;

  useEffect(() => {
    if (apiBase) {
      const controller = new AbortController();
      debouncedFetchModels(apiBase, apiKey, controller.signal);
      // Cleanup cancels both the pending debounce and any in-flight request.
      return () => {
        debouncedFetchModels.cancel();
        controller.abort();
      };
    } else {
      setFetchedModels([]);
    }
  }, [apiBase, apiKey, debouncedFetchModels, setFetchedModels]);

  // Prefer freshly-fetched models; fall back to the saved provider's list.
  const currentModels =
    fetchedModels.length > 0
      ? fetchedModels
      : existingLlmProvider?.model_configurations || [];

  return (
    <LLMConfigurationModalWrapper
      providerEndpoint={LLMProviderName.LM_STUDIO}
      existingProviderName={existingLlmProvider?.name}
      onClose={onClose}
      isFormValid={formikProps.isValid}
      isDirty={formikProps.dirty}
      isTesting={isTesting}
      isSubmitting={formikProps.isSubmitting}
    >
      <FieldWrapper>
        <InputLayouts.Vertical
          name="api_base"
          title="API Base URL"
          subDescription="The base URL for your LM Studio server."
        >
          <InputTypeInField
            name="api_base"
            placeholder="Your LM Studio API base URL"
          />
        </InputLayouts.Vertical>
      </FieldWrapper>
      <FieldWrapper>
        <InputLayouts.Vertical
          name="custom_config.LM_STUDIO_API_KEY"
          title="API Key"
          subDescription="Optional API key if your LM Studio server requires authentication."
          suffix="optional"
        >
          <PasswordInputTypeInField
            name="custom_config.LM_STUDIO_API_KEY"
            placeholder="API Key"
          />
        </InputLayouts.Vertical>
      </FieldWrapper>
      {!isOnboarding && (
        <>
          <FieldSeparator />
          <DisplayNameField disabled={!!existingLlmProvider} />
        </>
      )}
      <FieldSeparator />
      {isOnboarding ? (
        <SingleDefaultModelField placeholder="E.g. llama3.1" />
      ) : (
        <ModelsField
          modelConfigurations={currentModels}
          formikProps={formikProps}
          recommendedDefaultModel={null}
          shouldShowAutoUpdateToggle={false}
        />
      )}
      {!isOnboarding && (
        <>
          <FieldSeparator />
          <ModelsAccessField formikProps={formikProps} />
        </>
      )}
    </LLMConfigurationModalWrapper>
  );
}
/**
 * Top-level LM Studio provider configuration form.
 *
 * Renders nothing when `open === false`. Builds Formik initial values and a
 * Yup validation schema that differ between the onboarding variant and the
 * regular configuration variant, then delegates submission to either
 * submitOnboardingProvider or submitLLMProvider.
 */
export default function LMStudioForm({
  variant = "llm-configuration",
  existingLlmProvider,
  shouldMarkAsDefault,
  open,
  onOpenChange,
  defaultModelName,
  onboardingState,
  onboardingActions,
  llmDescriptor,
}: LLMProviderFormProps) {
  const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
  const [isTesting, setIsTesting] = useState(false);
  const isOnboarding = variant === "onboarding";
  const { mutate } = useSWRConfig();
  const { wellKnownLLMProvider } = useWellKnownLLMProvider(
    LLMProviderName.LM_STUDIO
  );

  // All hooks above run unconditionally, so this early return after them is
  // safe under the rules of hooks.
  if (open === false) return null;

  const onClose = () => onOpenChange?.(false);

  const modelConfigurations = buildAvailableModelConfigurations(
    existingLlmProvider,
    wellKnownLLMProvider ?? llmDescriptor
  );

  // NOTE(review): the `as LMStudioFormValues` assertion assumes
  // buildOnboardingInitialValues() supplies the remaining BaseLLMFormValues
  // fields — not visible from this file; confirm against utils.
  const initialValues: LMStudioFormValues = isOnboarding
    ? ({
        ...buildOnboardingInitialValues(),
        name: LLMProviderName.LM_STUDIO,
        provider: LLMProviderName.LM_STUDIO,
        api_base: DEFAULT_API_BASE,
        default_model_name: "",
        custom_config: {
          LM_STUDIO_API_KEY: "",
        },
      } as LMStudioFormValues)
    : {
        ...buildDefaultInitialValues(
          existingLlmProvider,
          modelConfigurations,
          defaultModelName
        ),
        api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
        custom_config: {
          LM_STUDIO_API_KEY:
            (existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ??
            "",
        },
      };

  // Onboarding only validates the two fields it shows; the full form layers
  // the api_base requirement on top of the shared default schema.
  const validationSchema = isOnboarding
    ? Yup.object().shape({
        api_base: Yup.string().required("API Base URL is required"),
        default_model_name: Yup.string().required("Model name is required"),
      })
    : buildDefaultValidationSchema().shape({
        api_base: Yup.string().required("API Base URL is required"),
      });

  return (
    <Formik
      initialValues={initialValues}
      validationSchema={validationSchema}
      validateOnMount={true}
      onSubmit={async (values, { setSubmitting }) => {
        // Drop empty custom_config entries so a blank API key is not
        // persisted; send undefined when nothing remains.
        const filteredCustomConfig = Object.fromEntries(
          Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
        );
        const submitValues = {
          ...values,
          custom_config:
            Object.keys(filteredCustomConfig).length > 0
              ? filteredCustomConfig
              : undefined,
        };
        if (isOnboarding && onboardingState && onboardingActions) {
          // Onboarding submits only live-fetched models (possibly none).
          const modelConfigsToUse =
            fetchedModels.length > 0 ? fetchedModels : [];
          await submitOnboardingProvider({
            providerName: LLMProviderName.LM_STUDIO,
            payload: {
              ...submitValues,
              model_configurations: modelConfigsToUse,
            },
            onboardingState,
            onboardingActions,
            isCustomProvider: false,
            onClose,
            setIsSubmitting: setSubmitting,
          });
        } else {
          await submitLLMProvider({
            providerName: LLMProviderName.LM_STUDIO,
            values: submitValues,
            initialValues,
            // Fall back to the statically known configurations when the
            // live fetch produced nothing.
            modelConfigurations:
              fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
            existingLlmProvider,
            shouldMarkAsDefault,
            setIsTesting,
            mutate,
            onClose,
            setSubmitting,
          });
        }
      }}
    >
      {(formikProps) => (
        <LMStudioFormInternals
          formikProps={formikProps}
          existingLlmProvider={existingLlmProvider}
          fetchedModels={fetchedModels}
          setFetchedModels={setFetchedModels}
          isTesting={isTesting}
          onClose={onClose}
          isOnboarding={isOnboarding}
        />
      )}
    </Formik>
  );
}

View File

@@ -1,213 +0,0 @@
"use client";
import { useCallback, useEffect, useMemo } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
} from "@/interfaces/llm";
import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues as BaseLLMModalValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIKeyField,
APIBaseField,
ModelSelectionField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const DEFAULT_API_BASE = "http://localhost:1234";
/**
 * Formik value shape for the (legacy) LM Studio modal.
 * Extends the shared base modal values with LM-Studio-specific fields.
 */
interface LMStudioModalValues extends BaseLLMModalValues {
  // Base URL of the LM Studio server (defaults to http://localhost:1234).
  api_base: string;
  custom_config: {
    // Optional key when the LM Studio server requires authentication;
    // empty entries are filtered out before submit (see LMStudioModal).
    LM_STUDIO_API_KEY?: string;
  };
}
/** Props for the modal body; Formik state is read via useFormikContext. */
interface LMStudioModalInternalsProps {
  // Present when editing an already-configured provider; undefined on create.
  existingLlmProvider: LLMProviderView | undefined;
  // Onboarding variant hides display-name/model-access fields.
  isOnboarding: boolean;
}
/**
 * Body of the (legacy) LM Studio modal.
 *
 * Watches the `api_base` / `LM_STUDIO_API_KEY` form values and, after a
 * 500 ms debounce, fetches the server's models into the
 * `model_configurations` Formik field, aborting stale requests.
 */
function LMStudioModalInternals({
  existingLlmProvider,
  isOnboarding,
}: LMStudioModalInternalsProps) {
  const formikProps = useFormikContext<LMStudioModalValues>();
  // Key stored on the saved provider; used to flag whether the typed key
  // differs from the persisted one.
  const initialApiKey = existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY;
  const doFetchModels = useCallback(
    (apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
      fetchModels(
        LLMProviderName.LM_STUDIO,
        {
          api_base: apiBase,
          custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
          api_key_changed: apiKey !== initialApiKey,
          name: existingLlmProvider?.name,
        },
        signal
      )
        .then((data) => {
          if (signal.aborted) return;
          if (data.error) {
            toast.error(data.error);
            formikProps.setFieldValue("model_configurations", []);
            return;
          }
          formikProps.setFieldValue("model_configurations", data.models);
        })
        // FIX: the promise previously had no rejection handler, so a network
        // failure (or the rejection raised when the controller aborts the
        // request) became an unhandled promise rejection instead of a
        // user-visible error. Aborts are intentional and stay silent.
        .catch((err: unknown) => {
          if (signal.aborted) return;
          toast.error(
            err instanceof Error ? err.message : "Failed to fetch models"
          );
          formikProps.setFieldValue("model_configurations", []);
        });
    },
    // formikProps deliberately omitted (would retrigger on every render).
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [existingLlmProvider?.name, initialApiKey]
  );
  // Debounce so we don't hammer the server while the user is still typing.
  const debouncedFetchModels = useMemo(
    () => debounce(doFetchModels, 500),
    [doFetchModels]
  );
  const apiBase = formikProps.values.api_base;
  const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
  useEffect(() => {
    if (apiBase) {
      const controller = new AbortController();
      debouncedFetchModels(apiBase, apiKey, controller.signal);
      // Cleanup cancels both the pending debounce and any in-flight request.
      return () => {
        debouncedFetchModels.cancel();
        controller.abort();
      };
    } else {
      formikProps.setFieldValue("model_configurations", []);
    }
    // formikProps deliberately omitted (see above).
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [apiBase, apiKey, debouncedFetchModels]);
  return (
    <>
      <APIBaseField
        subDescription="The base URL for your LM Studio server."
        placeholder="Your LM Studio API base URL"
      />
      <APIKeyField
        optional
        subDescription="Optional API key if your LM Studio server requires authentication."
      />
      {!isOnboarding && (
        <>
          <InputLayouts.FieldSeparator />
          <DisplayNameField disabled={!!existingLlmProvider} />
        </>
      )}
      <InputLayouts.FieldSeparator />
      <ModelSelectionField shouldShowAutoUpdateToggle={false} />
      {!isOnboarding && (
        <>
          <InputLayouts.FieldSeparator />
          <ModelAccessField />
        </>
      )}
    </>
  );
}
/**
 * Top-level (legacy) LM Studio provider modal.
 *
 * Builds Formik initial values from the existing provider (if any), wires
 * the shared ModalWrapper, and delegates submission to submitProvider with
 * an analytics source that distinguishes onboarding from the admin page.
 */
export default function LMStudioModal({
  variant = "llm-configuration",
  existingLlmProvider,
  shouldMarkAsDefault,
  onOpenChange,
  onSuccess,
}: LLMProviderFormProps) {
  const isOnboarding = variant === "onboarding";
  const { mutate } = useSWRConfig();
  const onClose = () => onOpenChange?.(false);
  // NOTE(review): the `as LMStudioModalValues` assertion assumes
  // useInitialValues supplies the remaining BaseLLMModalValues fields —
  // not visible from this file; confirm against utils.
  const initialValues: LMStudioModalValues = {
    ...useInitialValues(
      isOnboarding,
      LLMProviderName.LM_STUDIO,
      existingLlmProvider
    ),
    api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
    custom_config: {
      LM_STUDIO_API_KEY: existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY,
    },
  } as LMStudioModalValues;
  const validationSchema = buildValidationSchema(isOnboarding, {
    apiBase: true,
  });
  return (
    <ModalWrapper
      providerName={LLMProviderName.LM_STUDIO}
      llmProvider={existingLlmProvider}
      onClose={onClose}
      initialValues={initialValues}
      validationSchema={validationSchema}
      onSubmit={async (values, { setSubmitting, setStatus }) => {
        // Drop empty custom_config entries so a blank API key is not
        // persisted; send undefined when nothing remains.
        const filteredCustomConfig = Object.fromEntries(
          Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
        );
        const submitValues = {
          ...values,
          custom_config:
            Object.keys(filteredCustomConfig).length > 0
              ? filteredCustomConfig
              : undefined,
        };
        await submitProvider({
          analyticsSource: isOnboarding
            ? LLMProviderConfiguredSource.CHAT_ONBOARDING
            : LLMProviderConfiguredSource.ADMIN_PAGE,
          providerName: LLMProviderName.LM_STUDIO,
          values: submitValues,
          initialValues,
          existingLlmProvider,
          shouldMarkAsDefault,
          setStatus,
          setSubmitting,
          onClose,
          onSuccess: async () => {
            // Caller-provided onSuccess wins; otherwise refresh caches and
            // show the default toast.
            if (onSuccess) {
              await onSuccess();
            } else {
              await refreshLlmProviderCaches(mutate);
              toast.success(
                existingLlmProvider
                  ? "Provider updated successfully!"
                  : "Provider enabled successfully!"
              );
            }
          },
        });
      }}
    >
      <LMStudioModalInternals
        existingLlmProvider={existingLlmProvider}
        isOnboarding={isOnboarding}
      />
    </ModalWrapper>
  );
}

View File

@@ -1,32 +1,41 @@
"use client";
import { useEffect } from "react";
import { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchLiteLLMProxyModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
APIKeyField,
APIBaseField,
ModelSelectionField,
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const DEFAULT_API_BASE = "http://localhost:4000";
@@ -36,15 +45,30 @@ interface LiteLLMProxyModalValues extends BaseLLMFormValues {
}
interface LiteLLMProxyModalInternalsProps {
formikProps: FormikProps<LiteLLMProxyModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function LiteLLMProxyModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: LiteLLMProxyModalInternalsProps) {
const formikProps = useFormikContext<LiteLLMProxyModalValues>();
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const isFetchDisabled =
!formikProps.values.api_base || !formikProps.values.api_key;
@@ -53,12 +77,12 @@ function LiteLLMProxyModalInternals({
const { models, error } = await fetchLiteLLMProxyModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key,
provider_name: LLMProviderName.LITELLM_PROXY,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
setFetchedModels(models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -74,34 +98,58 @@ function LiteLLMProxyModalInternals({
}, []);
return (
<>
<APIBaseField
subDescription="The base URL for your LiteLLM Proxy server."
placeholder="https://your-litellm-proxy.com"
/>
<LLMConfigurationModalWrapper
providerEndpoint={LLMProviderName.LITELLM_PROXY}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="The base URL for your LiteLLM Proxy server."
>
<InputTypeInField
name="api_base"
placeholder="https://your-litellm-proxy.com"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<APIKeyField providerName="LiteLLM Proxy" />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</>
</LLMConfigurationModalWrapper>
);
}
@@ -109,67 +157,111 @@ export default function LiteLLMProxyModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
LLMProviderName.LITELLM_PROXY
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues: LiteLLMProxyModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.LITELLM_PROXY,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
} as LiteLLMProxyModalValues;
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
apiBase: true,
});
const initialValues: LiteLLMProxyModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: LLMProviderName.LITELLM_PROXY,
provider: LLMProviderName.LITELLM_PROXY,
api_key: "",
api_base: DEFAULT_API_BASE,
default_model_name: "",
} as LiteLLMProxyModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
});
return (
<ModalWrapper
providerName={LLMProviderName.LITELLM_PROXY}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.LITELLM_PROXY,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: LLMProviderName.LITELLM_PROXY,
payload: {
...values,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: LLMProviderName.LITELLM_PROXY,
values,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<LiteLLMProxyModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
{(formikProps) => (
<LiteLLMProxyModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
);
}

View File

@@ -1,44 +1,47 @@
"use client";
import * as Yup from "yup";
import { Dispatch, SetStateAction, useMemo, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
ModelSelectionField,
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
ModelsAccessField,
FieldSeparator,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
import debounce from "lodash/debounce";
import Tabs from "@/refresh-components/Tabs";
import { Card } from "@opal/components";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import useOnMount from "@/hooks/useOnMount";
const OLLAMA_PROVIDER_NAME = "ollama_chat";
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
const CLOUD_API_BASE = "https://ollama.com";
enum Tab {
TAB_SELF_HOSTED = "self-hosted",
TAB_CLOUD = "cloud",
}
const TAB_SELF_HOSTED = "self-hosted";
const TAB_CLOUD = "cloud";
interface OllamaModalValues extends BaseLLMFormValues {
api_base: string;
@@ -48,67 +51,104 @@ interface OllamaModalValues extends BaseLLMFormValues {
}
interface OllamaModalInternalsProps {
formikProps: FormikProps<OllamaModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
tab: Tab;
setTab: Dispatch<SetStateAction<Tab>>;
}
function OllamaModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
isTesting,
onClose,
isOnboarding,
tab,
setTab,
}: OllamaModalInternalsProps) {
const formikProps = useFormikContext<OllamaModalValues>();
const isInitialMount = useRef(true);
const isFetchDisabled = useMemo(
() =>
tab === Tab.TAB_SELF_HOSTED
? !formikProps.values.api_base
: !formikProps.values.custom_config.OLLAMA_API_KEY,
[tab, formikProps]
const doFetchModels = useCallback(
(apiBase: string, signal: AbortSignal) => {
fetchOllamaModels({
api_base: apiBase,
provider_name: existingLlmProvider?.name,
signal,
}).then((data) => {
if (signal.aborted) return;
if (data.error) {
toast.error(data.error);
setFetchedModels([]);
return;
}
setFetchedModels(data.models);
});
},
[existingLlmProvider?.name, setFetchedModels]
);
const handleFetchModels = async (signal?: AbortSignal) => {
// Only Ollama cloud accepts API key
const apiBase = formikProps.values.custom_config?.OLLAMA_API_KEY
? CLOUD_API_BASE
: formikProps.values.api_base;
const { models, error } = await fetchOllamaModels({
api_base: apiBase,
provider_name: existingLlmProvider?.name,
signal,
});
if (signal?.aborted) return;
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
};
const debouncedFetchModels = useMemo(
() => debounce(doFetchModels, 500),
[doFetchModels]
);
// Auto-fetch models on initial load when editing an existing provider
useOnMount(() => {
if (existingLlmProvider) {
handleFetchModels().catch((err) => {
toast.error(
err instanceof Error ? err.message : "Failed to fetch models"
);
});
// Skip the initial fetch for new providers — api_base starts with a default
// value, which would otherwise trigger a fetch before the user has done
// anything. Existing providers should still auto-fetch on mount.
useEffect(() => {
if (isInitialMount.current) {
isInitialMount.current = false;
if (!existingLlmProvider) return;
}
});
if (formikProps.values.api_base) {
const controller = new AbortController();
debouncedFetchModels(formikProps.values.api_base, controller.signal);
return () => {
debouncedFetchModels.cancel();
controller.abort();
};
} else {
setFetchedModels([]);
}
}, [
formikProps.values.api_base,
debouncedFetchModels,
setFetchedModels,
existingLlmProvider,
]);
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || [];
const hasApiKey = !!formikProps.values.custom_config?.OLLAMA_API_KEY;
const defaultTab =
existingLlmProvider && hasApiKey ? TAB_CLOUD : TAB_SELF_HOSTED;
return (
<>
<LLMConfigurationModalWrapper
providerEndpoint={OLLAMA_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<Card background="light" border="none" padding="sm">
<Tabs value={tab} onValueChange={(value) => setTab(value as Tab)}>
<Tabs defaultValue={defaultTab}>
<Tabs.List>
<Tabs.Trigger value={Tab.TAB_SELF_HOSTED}>
<Tabs.Trigger value={TAB_SELF_HOSTED}>
Self-hosted Ollama
</Tabs.Trigger>
<Tabs.Trigger value={Tab.TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
<Tabs.Trigger value={TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
</Tabs.List>
<Tabs.Content value={Tab.TAB_SELF_HOSTED} padding={0}>
<Tabs.Content value={TAB_SELF_HOSTED}>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
@@ -121,7 +161,7 @@ function OllamaModalInternals({
</InputLayouts.Vertical>
</Tabs.Content>
<Tabs.Content value={Tab.TAB_CLOUD}>
<Tabs.Content value={TAB_CLOUD}>
<InputLayouts.Vertical
name="custom_config.OLLAMA_API_KEY"
title="API Key"
@@ -138,24 +178,31 @@ function OllamaModalInternals({
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. llama3.1" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>
)}
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</>
</LLMConfigurationModalWrapper>
);
}
@@ -163,102 +210,125 @@ export default function OllamaModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const apiKey = existingLlmProvider?.custom_config?.OLLAMA_API_KEY;
const defaultTab =
existingLlmProvider && !!apiKey ? Tab.TAB_CLOUD : Tab.TAB_SELF_HOSTED;
const [tab, setTab] = useState<Tab>(defaultTab);
const { wellKnownLLMProvider } =
useWellKnownLLMProvider(OLLAMA_PROVIDER_NAME);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues: OllamaModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.OLLAMA_CHAT,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
OLLAMA_API_KEY: apiKey,
},
} as OllamaModalValues;
const validationSchema = useMemo(
() =>
buildValidationSchema(isOnboarding, {
apiBase: tab === Tab.TAB_SELF_HOSTED,
extra:
tab === Tab.TAB_CLOUD
? {
custom_config: Yup.object({
OLLAMA_API_KEY: Yup.string().required("API Key is required"),
}),
}
: undefined,
}),
[tab, isOnboarding]
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const initialValues: OllamaModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: OLLAMA_PROVIDER_NAME,
provider: OLLAMA_PROVIDER_NAME,
api_base: DEFAULT_API_BASE,
default_model_name: "",
custom_config: {
OLLAMA_API_KEY: "",
},
} as OllamaModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
custom_config: {
OLLAMA_API_KEY:
(existingLlmProvider?.custom_config?.OLLAMA_API_KEY as string) ??
"",
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_base: Yup.string().required("API Base URL is required"),
});
return (
<ModalWrapper
providerName={LLMProviderName.OLLAMA_CHAT}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
);
const submitValues = {
...values,
api_base: filteredCustomConfig.OLLAMA_API_KEY
? CLOUD_API_BASE
: values.api_base,
custom_config:
Object.keys(filteredCustomConfig).length > 0
? filteredCustomConfig
: undefined,
};
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.OLLAMA_CHAT,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: OLLAMA_PROVIDER_NAME,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: OLLAMA_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<OllamaModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
tab={tab}
setTab={setTab}
/>
</ModalWrapper>
{(formikProps) => (
<OllamaModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
);
}

View File

@@ -1,174 +0,0 @@
"use client";
import { useEffect } from "react";
import { markdown } from "@opal/utils";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
} from "@/interfaces/llm";
import { fetchOpenAICompatibleModels } from "@/app/admin/configuration/llm/utils";
import {
useInitialValues,
buildValidationSchema,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
APIBaseField,
APIKeyField,
ModelSelectionField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
// Form values for the OpenAI-compatible provider modal; extends the shared
// base LLM form shape with the two fields this provider adds.
interface OpenAICompatibleModalValues extends BaseLLMFormValues {
  // Always present in form state; the UI marks it optional (servers that do
  // not require authentication may leave it blank).
  api_key: string;
  // Base URL of the OpenAI-compatible server, e.g. "http://localhost:8000/v1".
  api_base: string;
}

// Props for the inner form body rendered inside the Formik-managed ModalWrapper.
interface OpenAICompatibleModalInternalsProps {
  // undefined when configuring a new provider; set when editing an existing one
  existingLlmProvider: LLMProviderView | undefined;
  // true when rendered in the onboarding flow (hides admin-only fields)
  isOnboarding: boolean;
}
/**
 * Inner form body for the OpenAI-compatible provider modal.
 *
 * Reads/writes form state via Formik context (must be rendered inside the
 * Formik provider set up by ModalWrapper). Renders the API base / API key
 * inputs, optional admin-only fields, and the model-selection field, and
 * wires up on-demand model fetching from the configured server.
 */
function OpenAICompatibleModalInternals({
  existingLlmProvider,
  isOnboarding,
}: OpenAICompatibleModalInternalsProps) {
  const formikProps = useFormikContext<OpenAICompatibleModalValues>();
  // Model fetching needs a server URL; without api_base the refetch action
  // is disabled (onRefetch is passed as undefined below).
  const isFetchDisabled = !formikProps.values.api_base;
  // Queries the server for its model list and stores the result in the
  // "model_configurations" form field. Throws on a fetch error so callers
  // (ModelSelectionField / the mount effect) can surface it.
  const handleFetchModels = async () => {
    const { models, error } = await fetchOpenAICompatibleModels({
      api_base: formikProps.values.api_base,
      // Empty string is normalized to undefined so an unauthenticated
      // request is sent when no key was entered.
      api_key: formikProps.values.api_key || undefined,
      provider_name: existingLlmProvider?.name,
    });
    if (error) {
      throw new Error(error);
    }
    formikProps.setFieldValue("model_configurations", models);
  };
  // Auto-fetch models on initial load when editing an existing provider
  useEffect(() => {
    if (existingLlmProvider && !isFetchDisabled) {
      // Fire-and-forget: surface failures as a toast rather than crashing
      // the modal. NOTE(review): there is no abort/unmount guard here, so a
      // slow response may call setFieldValue after unmount — presumed
      // acceptable for this modal's lifetime; confirm if warnings appear.
      handleFetchModels().catch((err) => {
        toast.error(
          err instanceof Error ? err.message : "Failed to fetch models"
        );
      });
    }
    // Deliberately run once on mount only (values captured from first
    // render); hence the exhaustive-deps suppression below.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);
  return (
    <>
      <APIBaseField
        subDescription="The base URL of your OpenAI-compatible server."
        placeholder="http://localhost:8000/v1"
      />
      <APIKeyField
        optional
        subDescription={markdown(
          "Provide an API key if your server requires authentication."
        )}
      />
      {/* Display name and model access are admin-page concerns; hidden during onboarding */}
      {!isOnboarding && (
        <>
          <InputLayouts.FieldSeparator />
          <DisplayNameField disabled={!!existingLlmProvider} />
        </>
      )}
      <InputLayouts.FieldSeparator />
      <ModelSelectionField
        shouldShowAutoUpdateToggle={false}
        onRefetch={isFetchDisabled ? undefined : handleFetchModels}
      />
      {!isOnboarding && (
        <>
          <InputLayouts.FieldSeparator />
          <ModelAccessField />
        </>
      )}
    </>
  );
}
/**
 * Modal for configuring an OpenAI-compatible LLM provider.
 *
 * Builds the Formik initial values and validation schema, then delegates
 * rendering to ModalWrapper (which owns the Formik provider) with
 * OpenAICompatibleModalInternals as the form body. On submit, the provider
 * is persisted via submitProvider; on success, either the caller-supplied
 * onSuccess callback runs or the SWR provider caches are refreshed and a
 * toast is shown.
 */
export default function OpenAICompatibleModal({
  variant = "llm-configuration",
  existingLlmProvider,
  shouldMarkAsDefault,
  onOpenChange,
  onSuccess,
}: LLMProviderFormProps) {
  const isOnboarding = variant === "onboarding";
  const { mutate } = useSWRConfig();
  // Closing is delegated to the parent via onOpenChange (optional chaining:
  // safe when no handler was provided).
  const onClose = () => onOpenChange?.(false);
  const initialValues = useInitialValues(
    isOnboarding,
    LLMProviderName.OPENAI_COMPATIBLE,
    existingLlmProvider
  ) as OpenAICompatibleModalValues;
  // Only api_base is unconditionally required for this provider; the API key
  // is optional (see the form body).
  const validationSchema = buildValidationSchema(isOnboarding, {
    apiBase: true,
  });
  return (
    <ModalWrapper
      providerName={LLMProviderName.OPENAI_COMPATIBLE}
      llmProvider={existingLlmProvider}
      onClose={onClose}
      initialValues={initialValues}
      validationSchema={validationSchema}
      onSubmit={async (values, { setSubmitting, setStatus }) => {
        await submitProvider({
          // Attribute the configuration event to the surface it came from
          analyticsSource: isOnboarding
            ? LLMProviderConfiguredSource.CHAT_ONBOARDING
            : LLMProviderConfiguredSource.ADMIN_PAGE,
          providerName: LLMProviderName.OPENAI_COMPATIBLE,
          values,
          initialValues,
          existingLlmProvider,
          shouldMarkAsDefault,
          setStatus,
          setSubmitting,
          onClose,
          onSuccess: async () => {
            // Caller-provided success handler takes precedence; otherwise
            // refresh provider caches and show the default toast.
            if (onSuccess) {
              await onSuccess();
            } else {
              await refreshLlmProviderCaches(mutate);
              toast.success(
                existingLlmProvider
                  ? "Provider updated successfully!"
                  : "Provider enabled successfully!"
              );
            }
          },
        });
      }}
    >
      <OpenAICompatibleModalInternals
        existingLlmProvider={existingLlmProvider}
        isOnboarding={isOnboarding}
      />
    </ModalWrapper>
  );
}

View File

@@ -1,99 +1,175 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
import { Formik } from "formik";
import { LLMProviderFormProps } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
APIKeyField,
ModelSelectionField,
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
FieldSeparator,
ModelsAccessField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import * as InputLayouts from "@/layouts/input-layouts";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
const OPENAI_PROVIDER_NAME = "openai";
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
export default function OpenAIModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } =
useWellKnownLLMProvider(OPENAI_PROVIDER_NAME);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues = useInitialValues(
isOnboarding,
LLMProviderName.OPENAI,
existingLlmProvider
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
});
const initialValues = isOnboarding
? {
...buildOnboardingInitialValues(),
name: OPENAI_PROVIDER_NAME,
provider: OPENAI_PROVIDER_NAME,
api_key: "",
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
}
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
default_model_name:
(defaultModelName &&
modelConfigurations.some((m) => m.name === defaultModelName)
? defaultModelName
: undefined) ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
});
return (
<ModalWrapper
providerName={LLMProviderName.OPENAI}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.OPENAI,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: OPENAI_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
is_auto_mode:
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: OPENAI_PROVIDER_NAME,
values,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<APIKeyField providerName="OpenAI" />
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={OPENAI_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<APIKeyField providerName="OpenAI" />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
{!isOnboarding && (
<>
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gpt-5.2" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
)}
{!isOnboarding && (
<>
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</LLMConfigurationModalWrapper>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
</Formik>
);
}

View File

@@ -1,50 +1,73 @@
"use client";
import { useEffect } from "react";
import { useState, useEffect } from "react";
import { useSWRConfig } from "swr";
import { useFormikContext } from "formik";
import { Formik, FormikProps } from "formik";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import {
LLMProviderFormProps,
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import { fetchOpenRouterModels } from "@/app/admin/configuration/llm/utils";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
APIKeyField,
APIBaseField,
ModelSelectionField,
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
ModelsAccessField,
FieldSeparator,
FieldWrapper,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { toast } from "@/hooks/useToast";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
const OPENROUTER_PROVIDER_NAME = "openrouter";
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
interface OpenRouterModalValues extends BaseLLMFormValues {
api_key: string;
api_base: string;
}
interface OpenRouterModalInternalsProps {
formikProps: FormikProps<OpenRouterModalValues>;
existingLlmProvider: LLMProviderView | undefined;
fetchedModels: ModelConfiguration[];
setFetchedModels: (models: ModelConfiguration[]) => void;
modelConfigurations: ModelConfiguration[];
isTesting: boolean;
onClose: () => void;
isOnboarding: boolean;
}
function OpenRouterModalInternals({
formikProps,
existingLlmProvider,
fetchedModels,
setFetchedModels,
modelConfigurations,
isTesting,
onClose,
isOnboarding,
}: OpenRouterModalInternalsProps) {
const formikProps = useFormikContext<OpenRouterModalValues>();
const currentModels =
fetchedModels.length > 0
? fetchedModels
: existingLlmProvider?.model_configurations || modelConfigurations;
const isFetchDisabled =
!formikProps.values.api_base || !formikProps.values.api_key;
@@ -53,12 +76,12 @@ function OpenRouterModalInternals({
const { models, error } = await fetchOpenRouterModels({
api_base: formikProps.values.api_base,
api_key: formikProps.values.api_key,
provider_name: LLMProviderName.OPENROUTER,
provider_name: existingLlmProvider?.name,
});
if (error) {
throw new Error(error);
}
formikProps.setFieldValue("model_configurations", models);
setFetchedModels(models);
};
// Auto-fetch models on initial load when editing an existing provider
@@ -74,34 +97,58 @@ function OpenRouterModalInternals({
}, []);
return (
<>
<APIBaseField
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
placeholder="Your OpenRouter base URL"
/>
<LLMConfigurationModalWrapper
providerEndpoint={OPENROUTER_PROVIDER_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<FieldWrapper>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
>
<InputTypeInField
name="api_base"
placeholder="Your OpenRouter base URL"
/>
</InputLayouts.Vertical>
</FieldWrapper>
<APIKeyField providerName="OpenRouter" />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. openai/gpt-4o" />
) : (
<ModelsField
modelConfigurations={currentModels}
formikProps={formikProps}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
/>
)}
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
<FieldSeparator />
<ModelsAccessField formikProps={formikProps} />
</>
)}
</>
</LLMConfigurationModalWrapper>
);
}
@@ -109,67 +156,111 @@ export default function OpenRouterModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
const [isTesting, setIsTesting] = useState(false);
const isOnboarding = variant === "onboarding";
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
OPENROUTER_PROVIDER_NAME
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues: OpenRouterModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.OPENROUTER,
existingLlmProvider
),
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
} as OpenRouterModalValues;
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
apiKey: true,
apiBase: true,
});
const initialValues: OpenRouterModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: OPENROUTER_PROVIDER_NAME,
provider: OPENROUTER_PROVIDER_NAME,
api_key: "",
api_base: DEFAULT_API_BASE,
default_model_name: "",
} as OpenRouterModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
};
const validationSchema = isOnboarding
? Yup.object().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
default_model_name: Yup.string().required("Model name is required"),
})
: buildDefaultValidationSchema().shape({
api_key: Yup.string().required("API Key is required"),
api_base: Yup.string().required("API Base URL is required"),
});
return (
<ModalWrapper
providerName={LLMProviderName.OPENROUTER}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.OPENROUTER,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
fetchedModels.length > 0 ? fetchedModels : [];
await submitOnboardingProvider({
providerName: OPENROUTER_PROVIDER_NAME,
payload: {
...values,
model_configurations: modelConfigsToUse,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: OPENROUTER_PROVIDER_NAME,
values,
initialValues,
modelConfigurations:
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<OpenRouterModalInternals
existingLlmProvider={existingLlmProvider}
isOnboarding={isOnboarding}
/>
</ModalWrapper>
{(formikProps) => (
<OpenRouterModalInternals
formikProps={formikProps}
existingLlmProvider={existingLlmProvider}
fetchedModels={fetchedModels}
setFetchedModels={setFetchedModels}
modelConfigurations={modelConfigurations}
isTesting={isTesting}
onClose={onClose}
isOnboarding={isOnboarding}
/>
)}
</Formik>
);
}

View File

@@ -1,27 +1,38 @@
"use client";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { Formik } from "formik";
import { FileUploadFormField } from "@/components/Field";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
import * as InputLayouts from "@/layouts/input-layouts";
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
import { LLMProviderFormProps } from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import {
useInitialValues,
buildValidationSchema,
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
buildOnboardingInitialValues,
BaseLLMFormValues,
} from "@/sections/modals/llmConfig/utils";
import { submitProvider } from "@/sections/modals/llmConfig/svc";
import { LLMProviderConfiguredSource } from "@/lib/analytics";
import {
ModelSelectionField,
submitLLMProvider,
submitOnboardingProvider,
} from "@/sections/modals/llmConfig/svc";
import {
ModelsField,
DisplayNameField,
ModelAccessField,
ModalWrapper,
FieldSeparator,
FieldWrapper,
ModelsAccessField,
SingleDefaultModelField,
LLMConfigurationModalWrapper,
} from "@/sections/modals/llmConfig/shared";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
const VERTEXAI_DEFAULT_LOCATION = "global";
interface VertexAIModalValues extends BaseLLMFormValues {
@@ -35,49 +46,89 @@ export default function VertexAIModal({
variant = "llm-configuration",
existingLlmProvider,
shouldMarkAsDefault,
open,
onOpenChange,
onSuccess,
defaultModelName,
onboardingState,
onboardingActions,
llmDescriptor,
}: LLMProviderFormProps) {
const isOnboarding = variant === "onboarding";
const [isTesting, setIsTesting] = useState(false);
const { mutate } = useSWRConfig();
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
VERTEXAI_PROVIDER_NAME
);
if (open === false) return null;
const onClose = () => onOpenChange?.(false);
const initialValues: VertexAIModalValues = {
...useInitialValues(
isOnboarding,
LLMProviderName.VERTEX_AI,
existingLlmProvider
),
custom_config: {
vertex_credentials:
(existingLlmProvider?.custom_config?.vertex_credentials as string) ??
"",
vertex_location:
(existingLlmProvider?.custom_config?.vertex_location as string) ??
VERTEXAI_DEFAULT_LOCATION,
},
} as VertexAIModalValues;
const modelConfigurations = buildAvailableModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? llmDescriptor
);
const validationSchema = buildValidationSchema(isOnboarding, {
extra: {
custom_config: Yup.object({
vertex_credentials: Yup.string().required(
"Credentials file is required"
const initialValues: VertexAIModalValues = isOnboarding
? ({
...buildOnboardingInitialValues(),
name: VERTEXAI_PROVIDER_NAME,
provider: VERTEXAI_PROVIDER_NAME,
default_model_name: VERTEXAI_DEFAULT_MODEL,
custom_config: {
vertex_credentials: "",
vertex_location: VERTEXAI_DEFAULT_LOCATION,
},
} as VertexAIModalValues)
: {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations,
defaultModelName
),
vertex_location: Yup.string(),
}),
},
});
default_model_name:
(defaultModelName &&
modelConfigurations.some((m) => m.name === defaultModelName)
? defaultModelName
: undefined) ??
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
custom_config: {
vertex_credentials:
(existingLlmProvider?.custom_config
?.vertex_credentials as string) ?? "",
vertex_location:
(existingLlmProvider?.custom_config?.vertex_location as string) ??
VERTEXAI_DEFAULT_LOCATION,
},
};
const validationSchema = isOnboarding
? Yup.object().shape({
default_model_name: Yup.string().required("Model name is required"),
custom_config: Yup.object({
vertex_credentials: Yup.string().required(
"Credentials file is required"
),
vertex_location: Yup.string(),
}),
})
: buildDefaultValidationSchema().shape({
custom_config: Yup.object({
vertex_credentials: Yup.string().required(
"Credentials file is required"
),
vertex_location: Yup.string(),
}),
});
return (
<ModalWrapper
providerName={LLMProviderName.VERTEX_AI}
llmProvider={existingLlmProvider}
onClose={onClose}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={async (values, { setSubmitting, setStatus }) => {
validateOnMount={true}
onSubmit={async (values, { setSubmitting }) => {
const filteredCustomConfig = Object.fromEntries(
Object.entries(values.custom_config || {}).filter(
([key, v]) => key === "vertex_credentials" || v !== ""
@@ -92,75 +143,101 @@ export default function VertexAIModal({
: undefined,
};
await submitProvider({
analyticsSource: isOnboarding
? LLMProviderConfiguredSource.CHAT_ONBOARDING
: LLMProviderConfiguredSource.ADMIN_PAGE,
providerName: LLMProviderName.VERTEX_AI,
values: submitValues,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
setStatus,
setSubmitting,
onClose,
onSuccess: async () => {
if (onSuccess) {
await onSuccess();
} else {
await refreshLlmProviderCaches(mutate);
toast.success(
existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!"
);
}
},
});
if (isOnboarding && onboardingState && onboardingActions) {
const modelConfigsToUse =
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
await submitOnboardingProvider({
providerName: VERTEXAI_PROVIDER_NAME,
payload: {
...submitValues,
model_configurations: modelConfigsToUse,
is_auto_mode:
values.default_model_name === VERTEXAI_DEFAULT_MODEL,
},
onboardingState,
onboardingActions,
isCustomProvider: false,
onClose,
setIsSubmitting: setSubmitting,
});
} else {
await submitLLMProvider({
providerName: VERTEXAI_PROVIDER_NAME,
values: submitValues,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
setIsTesting,
mutate,
onClose,
setSubmitting,
});
}
}}
>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="custom_config.vertex_location"
title="Google Cloud Region Name"
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
{(formikProps) => (
<LLMConfigurationModalWrapper
providerEndpoint={VERTEXAI_PROVIDER_NAME}
providerName={VERTEXAI_DISPLAY_NAME}
existingProviderName={existingLlmProvider?.name}
onClose={onClose}
isFormValid={formikProps.isValid}
isDirty={formikProps.dirty}
isTesting={isTesting}
isSubmitting={formikProps.isSubmitting}
>
<InputTypeInField
name="custom_config.vertex_location"
placeholder={VERTEXAI_DEFAULT_LOCATION}
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="custom_config.vertex_location"
title="Google Cloud Region Name"
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
>
<InputTypeInField
name="custom_config.vertex_location"
placeholder={VERTEXAI_DEFAULT_LOCATION}
/>
</InputLayouts.Vertical>
</FieldWrapper>
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="custom_config.vertex_credentials"
title="API Key"
subDescription="Attach your API key JSON from Google Cloud to access your models."
>
<FileUploadFormField
name="custom_config.vertex_credentials"
label=""
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="custom_config.vertex_credentials"
title="API Key"
subDescription="Attach your API key JSON from Google Cloud to access your models."
>
<FileUploadFormField
name="custom_config.vertex_credentials"
label=""
/>
</InputLayouts.Vertical>
</FieldWrapper>
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<DisplayNameField disabled={!!existingLlmProvider} />
</>
<FieldSeparator />
{!isOnboarding && (
<DisplayNameField disabled={!!existingLlmProvider} />
)}
<FieldSeparator />
{isOnboarding ? (
<SingleDefaultModelField placeholder="E.g. gemini-2.5-pro" />
) : (
<ModelsField
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
)}
{!isOnboarding && <ModelsAccessField formikProps={formikProps} />}
</LLMConfigurationModalWrapper>
)}
<InputLayouts.FieldSeparator />
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
{!isOnboarding && (
<>
<InputLayouts.FieldSeparator />
<ModelAccessField />
</>
)}
</ModalWrapper>
</Formik>
);
}

View File

@@ -7,68 +7,58 @@ import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
return (
provider.provider === LLMProviderName.OPENAI &&
provider.api_key &&
!provider.api_base &&
Object.keys(provider.custom_config || {}).length === 0
);
}
export function getModalForExistingProvider(
provider: LLMProviderView,
open?: boolean,
onOpenChange?: (open: boolean) => void,
defaultModelName?: string
) {
const props = {
existingLlmProvider: provider,
open,
onOpenChange,
defaultModelName,
};
const hasCustomConfig = provider.custom_config != null;
switch (provider.provider) {
// These providers don't use custom_config themselves, so a non-null
// custom_config means the provider was created via CustomModal.
case LLMProviderName.OPENAI:
return hasCustomConfig ? (
<CustomModal {...props} />
) : (
<OpenAIModal {...props} />
);
// "openai" as a provider name can be used for litellm proxy / any OpenAI-compatible provider
if (detectIfRealOpenAIProvider(provider)) {
return <OpenAIModal {...props} />;
} else {
return <CustomModal {...props} />;
}
case LLMProviderName.ANTHROPIC:
return hasCustomConfig ? (
<CustomModal {...props} />
) : (
<AnthropicModal {...props} />
);
case LLMProviderName.AZURE:
return hasCustomConfig ? (
<CustomModal {...props} />
) : (
<AzureModal {...props} />
);
case LLMProviderName.OPENROUTER:
return hasCustomConfig ? (
<CustomModal {...props} />
) : (
<OpenRouterModal {...props} />
);
// These providers legitimately store settings in custom_config,
// so always use their dedicated modals.
return <AnthropicModal {...props} />;
case LLMProviderName.OLLAMA_CHAT:
return <OllamaModal {...props} />;
case LLMProviderName.AZURE:
return <AzureModal {...props} />;
case LLMProviderName.VERTEX_AI:
return <VertexAIModal {...props} />;
case LLMProviderName.BEDROCK:
return <BedrockModal {...props} />;
case LLMProviderName.OPENROUTER:
return <OpenRouterModal {...props} />;
case LLMProviderName.LM_STUDIO:
return <LMStudioModal {...props} />;
return <LMStudioForm {...props} />;
case LLMProviderName.LITELLM_PROXY:
return <LiteLLMProxyModal {...props} />;
case LLMProviderName.BIFROST:
return <BifrostModal {...props} />;
case LLMProviderName.OPENAI_COMPATIBLE:
return <OpenAICompatibleModal {...props} />;
default:
return <CustomModal {...props} />;
}

View File

@@ -1,12 +1,11 @@
"use client";
import React, { useEffect, useRef, useState } from "react";
import { Formik, Form, useFormikContext } from "formik";
import type { FormikConfig } from "formik";
import { ReactNode, useState } from "react";
import { Form, FormikProps } from "formik";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useAgents } from "@/hooks/useAgents";
import { useUserGroups } from "@/lib/hooks";
import { LLMProviderView, ModelConfiguration } from "@/interfaces/llm";
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
import * as InputLayouts from "@/layouts/input-layouts";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
@@ -16,10 +15,12 @@ import InputSelect from "@/refresh-components/inputs/InputSelect";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import Switch from "@/refresh-components/inputs/Switch";
import Text from "@/refresh-components/texts/Text";
import { Button, LineItemButton } from "@opal/components";
import { Button, LineItemButton, Tag } from "@opal/components";
import { BaseLLMFormValues } from "@/sections/modals/llmConfig/utils";
import type { RichStr } from "@opal/types";
import { WithoutStyles } from "@opal/types";
import Separator from "@/refresh-components/Separator";
import { Section } from "@/layouts/general-layouts";
import { Hoverable } from "@opal/core";
import { Content } from "@opal/layouts";
import {
SvgArrowExchange,
@@ -47,14 +48,27 @@ import {
getProviderProductName,
} from "@/lib/llmConfig/providers";
export function FieldSeparator() {
return <Separator noPadding className="px-2" />;
}
export type FieldWrapperProps = WithoutStyles<
React.HTMLAttributes<HTMLDivElement>
>;
export function FieldWrapper(props: FieldWrapperProps) {
return <div {...props} className="p-2 w-full" />;
}
// ─── DisplayNameField ────────────────────────────────────────────────────────
export interface DisplayNameFieldProps {
disabled?: boolean;
}
export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
return (
<InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="name"
title="Display Name"
@@ -66,7 +80,7 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
variant={disabled ? "disabled" : undefined}
/>
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
</FieldWrapper>
);
}
@@ -75,56 +89,47 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
export interface APIKeyFieldProps {
optional?: boolean;
providerName?: string;
subDescription?: string | RichStr;
}
export function APIKeyField({
optional = false,
providerName,
subDescription,
}: APIKeyFieldProps) {
return (
<InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Vertical
name="api_key"
title="API Key"
subDescription={
subDescription
? subDescription
: providerName
? `Paste your API key from ${providerName} to access your models.`
: "Paste your API key to access your models."
providerName
? `Paste your API key from ${providerName} to access your models.`
: "Paste your API key to access your models."
}
suffix={optional ? "optional" : undefined}
>
<PasswordInputTypeInField name="api_key" />
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
</FieldWrapper>
);
}
// ─── APIBaseField ───────────────────────────────────────────────────────────
// ─── SingleDefaultModelField ─────────────────────────────────────────────────
export interface APIBaseFieldProps {
optional?: boolean;
subDescription?: string | RichStr;
export interface SingleDefaultModelFieldProps {
placeholder?: string;
}
export function APIBaseField({
optional = false,
subDescription,
placeholder = "https://",
}: APIBaseFieldProps) {
export function SingleDefaultModelField({
placeholder = "E.g. gpt-4o",
}: SingleDefaultModelFieldProps) {
return (
<InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="api_base"
title="API Base URL"
subDescription={subDescription}
suffix={optional ? "optional" : undefined}
>
<InputTypeInField name="api_base" placeholder={placeholder} />
</InputLayouts.Vertical>
</InputLayouts.FieldPadder>
<InputLayouts.Vertical
name="default_model_name"
title="Default Model"
description="The model to use by default for this provider unless otherwise specified."
>
<InputTypeInField name="default_model_name" placeholder={placeholder} />
</InputLayouts.Vertical>
);
}
@@ -134,8 +139,13 @@ export function APIBaseField({
const GROUP_PREFIX = "group:";
const AGENT_PREFIX = "agent:";
export function ModelAccessField() {
const formikProps = useFormikContext<BaseLLMFormValues>();
interface ModelsAccessFieldProps<T> {
formikProps: FormikProps<T>;
}
export function ModelsAccessField<T extends BaseLLMFormValues>({
formikProps,
}: ModelsAccessFieldProps<T>) {
const { agents } = useAgents();
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const { data: usersData } = useUsers({ includeApiKeys: false });
@@ -218,7 +228,7 @@ export function ModelAccessField() {
return (
<div className="flex flex-col w-full">
<InputLayouts.FieldPadder>
<FieldWrapper>
<InputLayouts.Horizontal
name="is_public"
title="Models Access"
@@ -239,7 +249,7 @@ export function ModelAccessField() {
</InputSelect.Content>
</InputSelect>
</InputLayouts.Horizontal>
</InputLayouts.FieldPadder>
</FieldWrapper>
{!isPublic && (
<Card background="light" border="none" padding="sm">
@@ -305,7 +315,7 @@ export function ModelAccessField() {
</div>
)}
<InputLayouts.FieldSeparator />
<FieldSeparator />
{selectedAgentIds.length > 0 ? (
<div className="grid grid-cols-2 gap-1 w-full">
@@ -358,115 +368,90 @@ export function ModelAccessField() {
);
}
// ─── RefetchButton ──────────────────────────────────────────────────
/**
* Manages an AbortController so that clicking the button cancels any
* in-flight fetch before starting a new one. Also aborts on unmount.
*/
interface RefetchButtonProps {
onRefetch: (signal: AbortSignal) => Promise<void> | void;
}
function RefetchButton({ onRefetch }: RefetchButtonProps) {
const abortRef = useRef<AbortController | null>(null);
const [isFetching, setIsFetching] = useState(false);
useEffect(() => {
return () => abortRef.current?.abort();
}, []);
return (
<Button
prominence="tertiary"
icon={isFetching ? SimpleLoader : SvgRefreshCw}
onClick={async () => {
abortRef.current?.abort();
const controller = new AbortController();
abortRef.current = controller;
setIsFetching(true);
try {
await onRefetch(controller.signal);
} catch (err) {
if (err instanceof DOMException && err.name === "AbortError") return;
toast.error(
err instanceof Error ? err.message : "Failed to fetch models"
);
} finally {
if (!controller.signal.aborted) {
setIsFetching(false);
}
}
}}
disabled={isFetching}
/>
);
}
// ─── ModelsField ─────────────────────────────────────────────────────
export interface ModelSelectionFieldProps {
export interface ModelsFieldProps<T> {
formikProps: FormikProps<T>;
modelConfigurations: ModelConfiguration[];
recommendedDefaultModel: SimpleKnownModel | null;
shouldShowAutoUpdateToggle: boolean;
onRefetch?: (signal: AbortSignal) => Promise<void> | void;
/** Called when the user clicks the refresh button to re-fetch models. */
onRefetch?: () => Promise<void> | void;
/** Called when the user adds a custom model by name. Enables the "Add Model" input. */
onAddModel?: (modelName: string) => void;
}
export function ModelSelectionField({
export function ModelsField<T extends BaseLLMFormValues>({
formikProps,
modelConfigurations,
recommendedDefaultModel,
shouldShowAutoUpdateToggle,
onRefetch,
onAddModel,
}: ModelSelectionFieldProps) {
const formikProps = useFormikContext<BaseLLMFormValues>();
}: ModelsFieldProps<T>) {
const [newModelName, setNewModelName] = useState("");
const isAutoMode = formikProps.values.is_auto_mode;
const models = formikProps.values.model_configurations;
const selectedModels = formikProps.values.selected_model_names ?? [];
const defaultModel = formikProps.values.default_model_name;
// Snapshot the original model visibility so we can restore it when
// toggling auto mode back on.
const originalModelsRef = useRef(models);
useEffect(() => {
if (originalModelsRef.current.length === 0 && models.length > 0) {
originalModelsRef.current = models;
function handleCheckboxChange(modelName: string, checked: boolean) {
// Read current values inside the handler to avoid stale closure issues
const currentSelected = formikProps.values.selected_model_names ?? [];
const currentDefault = formikProps.values.default_model_name;
if (checked) {
const newSelected = [...currentSelected, modelName];
formikProps.setFieldValue("selected_model_names", newSelected);
// If this is the first model, set it as default
if (currentSelected.length === 0) {
formikProps.setFieldValue("default_model_name", modelName);
}
} else {
const newSelected = currentSelected.filter((name) => name !== modelName);
formikProps.setFieldValue("selected_model_names", newSelected);
// If removing the default, set the first remaining model as default
if (currentDefault === modelName && newSelected.length > 0) {
formikProps.setFieldValue("default_model_name", newSelected[0]);
} else if (newSelected.length === 0) {
formikProps.setFieldValue("default_model_name", undefined);
}
}
}, [models]);
}
// Automatically derive test_model_name from model_configurations.
// Any change to visibility or the model list syncs this automatically.
useEffect(() => {
const firstVisible = models.find((m) => m.is_visible)?.name;
if (firstVisible !== formikProps.values.test_model_name) {
formikProps.setFieldValue("test_model_name", firstVisible);
}
}, [models]); // eslint-disable-line react-hooks/exhaustive-deps
function setVisibility(modelName: string, visible: boolean) {
const updated = models.map((m) =>
m.name === modelName ? { ...m, is_visible: visible } : m
);
formikProps.setFieldValue("model_configurations", updated);
function handleSetDefault(modelName: string) {
formikProps.setFieldValue("default_model_name", modelName);
}
function handleToggleAutoMode(nextIsAutoMode: boolean) {
formikProps.setFieldValue("is_auto_mode", nextIsAutoMode);
if (nextIsAutoMode) {
formikProps.setFieldValue(
"model_configurations",
originalModelsRef.current
);
formikProps.setFieldValue(
"selected_model_names",
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
);
formikProps.setFieldValue(
"default_model_name",
recommendedDefaultModel?.name ?? undefined
);
}
const allSelected =
modelConfigurations.length > 0 &&
modelConfigurations.every((m) => selectedModels.includes(m.name));
function handleToggleSelectAll() {
if (allSelected) {
formikProps.setFieldValue("selected_model_names", []);
formikProps.setFieldValue("default_model_name", undefined);
} else {
const allNames = modelConfigurations.map((m) => m.name);
formikProps.setFieldValue("selected_model_names", allNames);
if (!formikProps.values.default_model_name && allNames.length > 0) {
formikProps.setFieldValue("default_model_name", allNames[0]);
}
}
}
const allSelected = models.length > 0 && models.every((m) => m.is_visible);
function handleToggleSelectAll() {
const nextVisible = !allSelected;
const updated = models.map((m) => ({
...m,
is_visible: nextVisible,
}));
formikProps.setFieldValue("model_configurations", updated);
}
const visibleModels = models.filter((m) => m.is_visible);
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
return (
<Card background="light" border="none" padding="sm">
@@ -479,45 +464,118 @@ export function ModelSelectionField({
>
<Section flexDirection="row" gap={0}>
<Button
disabled={isAutoMode || models.length === 0}
disabled={isAutoMode || modelConfigurations.length === 0}
prominence="tertiary"
size="md"
onClick={handleToggleSelectAll}
>
{allSelected ? "Unselect All" : "Select All"}
</Button>
{onRefetch && <RefetchButton onRefetch={onRefetch} />}
{onRefetch && (
<Button
prominence="tertiary"
icon={SvgRefreshCw}
onClick={async () => {
try {
await onRefetch();
} catch (err) {
toast.error(
err instanceof Error
? err.message
: "Failed to fetch models"
);
}
}}
/>
)}
</Section>
</InputLayouts.Horizontal>
{models.length === 0 ? (
{modelConfigurations.length === 0 ? (
<EmptyMessageCard title="No models available." padding="sm" />
) : (
<Section gap={0.25}>
{isAutoMode
? visibleModels.map((model) => (
<LineItemButton
? // Auto mode: read-only display
visibleModels.map((model) => (
<Hoverable.Root
key={model.name}
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state="selected"
icon={() => <Checkbox checked />}
title={model.display_name || model.name}
/>
group="LLMConfigurationButton"
widthVariant="full"
>
<LineItemButton
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state="selected"
icon={() => <Checkbox checked />}
title={model.display_name || model.name}
rightChildren={
model.name === defaultModel ? (
<Section>
<Tag title="Default Model" color="blue" />
</Section>
) : undefined
}
/>
</Hoverable.Root>
))
: models.map((model) => (
<LineItemButton
key={model.name}
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state={model.is_visible ? "selected" : "empty"}
icon={() => <Checkbox checked={model.is_visible} />}
title={model.name}
onClick={() => setVisibility(model.name, !model.is_visible)}
/>
))}
: // Manual mode: checkbox selection
modelConfigurations.map((modelConfiguration) => {
const isSelected = selectedModels.includes(
modelConfiguration.name
);
const isDefault = defaultModel === modelConfiguration.name;
return (
<Hoverable.Root
key={modelConfiguration.name}
group="LLMConfigurationButton"
widthVariant="full"
>
<LineItemButton
variant="section"
sizePreset="main-ui"
selectVariant="select-heavy"
state={isSelected ? "selected" : "empty"}
icon={() => <Checkbox checked={isSelected} />}
title={modelConfiguration.name}
onClick={() =>
handleCheckboxChange(
modelConfiguration.name,
!isSelected
)
}
rightChildren={
isSelected ? (
isDefault ? (
<Section>
<Tag color="blue" title="Default Model" />
</Section>
) : (
<Hoverable.Item
group="LLMConfigurationButton"
variant="opacity-on-hover"
>
<Button
size="sm"
prominence="internal"
onClick={(e) => {
e.stopPropagation();
handleSetDefault(modelConfiguration.name);
}}
type="button"
>
Set as default
</Button>
</Hoverable.Item>
)
) : undefined
}
/>
</Hoverable.Root>
);
})}
</Section>
)}
@@ -532,7 +590,7 @@ export function ModelSelectionField({
if (e.key === "Enter" && newModelName.trim()) {
e.preventDefault();
const trimmed = newModelName.trim();
if (!models.some((m) => m.name === trimmed)) {
if (!modelConfigurations.some((m) => m.name === trimmed)) {
onAddModel(trimmed);
setNewModelName("");
}
@@ -547,11 +605,14 @@ export function ModelSelectionField({
type="button"
disabled={
!newModelName.trim() ||
models.some((m) => m.name === newModelName.trim())
modelConfigurations.some((m) => m.name === newModelName.trim())
}
onClick={() => {
const trimmed = newModelName.trim();
if (trimmed && !models.some((m) => m.name === trimmed)) {
if (
trimmed &&
!modelConfigurations.some((m) => m.name === trimmed)
) {
onAddModel(trimmed);
setNewModelName("");
}
@@ -578,87 +639,41 @@ export function ModelSelectionField({
);
}
// ─── ModalWrapper ─────────────────────────────────────────────────────
// ============================================================================
// LLMConfigurationModalWrapper
// ============================================================================
export interface ModalWrapperProps<
T extends BaseLLMFormValues = BaseLLMFormValues,
> {
providerName: string;
llmProvider?: LLMProviderView;
interface LLMConfigurationModalWrapperProps {
providerEndpoint: string;
providerName?: string;
existingProviderName?: string;
onClose: () => void;
initialValues: T;
validationSchema: FormikConfig<T>["validationSchema"];
onSubmit: FormikConfig<T>["onSubmit"];
children: React.ReactNode;
isFormValid: boolean;
isDirty?: boolean;
isTesting?: boolean;
isSubmitting?: boolean;
children: ReactNode;
}
export function ModalWrapper<T extends BaseLLMFormValues = BaseLLMFormValues>({
export function LLMConfigurationModalWrapper({
providerEndpoint,
providerName,
llmProvider,
existingProviderName,
onClose,
initialValues,
validationSchema,
onSubmit,
isFormValid,
isDirty,
isTesting,
isSubmitting,
children,
}: ModalWrapperProps<T>) {
return (
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
validateOnMount
onSubmit={onSubmit}
>
{() => (
<ModalWrapperInner
providerName={providerName}
llmProvider={llmProvider}
onClose={onClose}
modelConfigurations={initialValues.model_configurations}
>
{children}
</ModalWrapperInner>
)}
</Formik>
);
}
interface ModalWrapperInnerProps {
providerName: string;
llmProvider?: LLMProviderView;
onClose: () => void;
modelConfigurations?: ModelConfiguration[];
children: React.ReactNode;
}
function ModalWrapperInner({
providerName,
llmProvider,
onClose,
modelConfigurations,
children,
}: ModalWrapperInnerProps) {
const { isValid, dirty, isSubmitting, status, setFieldValue, values } =
useFormikContext<BaseLLMFormValues>();
// When SWR resolves after mount, populate model_configurations if still
// empty. test_model_name is then derived automatically by
// ModelSelectionField's useEffect.
useEffect(() => {
if (
modelConfigurations &&
modelConfigurations.length > 0 &&
values.model_configurations.length === 0
) {
setFieldValue("model_configurations", modelConfigurations);
}
}, [modelConfigurations]); // eslint-disable-line react-hooks/exhaustive-deps
const isTesting = status?.isTesting === true;
}: LLMConfigurationModalWrapperProps) {
const busy = isTesting || isSubmitting;
const providerIcon = getProviderIcon(providerName);
const providerDisplayName = getProviderDisplayName(providerName);
const providerProductName = getProviderProductName(providerName);
const providerIcon = getProviderIcon(providerEndpoint);
const providerDisplayName =
providerName ?? getProviderDisplayName(providerEndpoint);
const providerProductName = getProviderProductName(providerEndpoint);
const title = llmProvider
? `Configure "${llmProvider.name}"`
const title = existingProviderName
? `Configure "${existingProviderName}"`
: `Set up ${providerProductName}`;
const description = `Connect to ${providerDisplayName} and set up your ${providerProductName} models.`;
@@ -674,7 +689,7 @@ function ModalWrapperInner({
description={description}
onClose={onClose}
/>
<Modal.Body padding={0.5} gap={0}>
<Modal.Body padding={0.5} gap={0.5}>
{children}
</Modal.Body>
<Modal.Footer>
@@ -682,11 +697,13 @@ function ModalWrapperInner({
Cancel
</Button>
<Button
disabled={!isValid || !dirty || busy}
disabled={
!isFormValid || busy || (!!existingProviderName && !isDirty)
}
type="submit"
icon={busy ? SimpleLoader : undefined}
>
{llmProvider?.name
{existingProviderName
? busy
? "Updating"
: "Update"

View File

@@ -1,8 +1,13 @@
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
import {
LLMProviderName,
LLMProviderView,
ModelConfiguration,
} from "@/interfaces/llm";
import {
LLM_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/lib/llmConfig/constants";
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
import { toast } from "@/hooks/useToast";
import isEqual from "lodash/isEqual";
import { parseAzureTargetUri } from "@/lib/azureTargetUri";
@@ -13,11 +18,13 @@ import {
} from "@/lib/analytics";
import {
BaseLLMFormValues,
SubmitLLMProviderParams,
SubmitOnboardingProviderParams,
TestApiKeyResult,
filterModelConfigurations,
getAutoModeModelConfigurations,
} from "@/sections/modals/llmConfig/utils";
// ─── Test helpers ─────────────────────────────────────────────────────────
const submitLlmTestRequest = async (
payload: Record<string, unknown>,
fallbackErrorMessage: string
@@ -43,6 +50,161 @@ const submitLlmTestRequest = async (
}
};
export const submitLLMProvider = async <T extends BaseLLMFormValues>({
providerName,
values,
initialValues,
modelConfigurations,
existingLlmProvider,
shouldMarkAsDefault,
hideSuccess,
setIsTesting,
mutate,
onClose,
setSubmitting,
}: SubmitLLMProviderParams<T>): Promise<void> => {
setSubmitting(true);
const { selected_model_names: visibleModels, api_key, ...rest } = values;
// In auto mode, use recommended models from descriptor
// In manual mode, use user's selection
let filteredModelConfigurations: ModelConfiguration[];
let finalDefaultModelName = rest.default_model_name;
if (values.is_auto_mode) {
filteredModelConfigurations =
getAutoModeModelConfigurations(modelConfigurations);
// In auto mode, use the first recommended model as default if current default isn't in the list
const visibleModelNames = new Set(
filteredModelConfigurations.map((m) => m.name)
);
if (
finalDefaultModelName &&
!visibleModelNames.has(finalDefaultModelName)
) {
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
}
} else {
filteredModelConfigurations = filterModelConfigurations(
modelConfigurations,
visibleModels,
rest.default_model_name as string | undefined
);
}
const customConfigChanged = !isEqual(
values.custom_config,
initialValues.custom_config
);
const normalizedApiBase =
typeof rest.api_base === "string" && rest.api_base.trim() === ""
? undefined
: rest.api_base;
const finalValues = {
...rest,
api_base: normalizedApiBase,
default_model_name: finalDefaultModelName,
api_key,
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
custom_config_changed: customConfigChanged,
model_configurations: filteredModelConfigurations,
};
// Test the configuration
if (!isEqual(finalValues, initialValues)) {
setIsTesting(true);
const response = await fetch("/api/admin/llm/test", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider: providerName,
...finalValues,
model: finalDefaultModelName,
id: existingLlmProvider?.id,
}),
});
setIsTesting(false);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
toast.error(errorMsg);
setSubmitting(false);
return;
}
}
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}${
existingLlmProvider ? "" : "?is_creation=true"
}`,
{
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider: providerName,
...finalValues,
id: existingLlmProvider?.id,
}),
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
const fullErrorMsg = existingLlmProvider
? `Failed to update provider: ${errorMsg}`
: `Failed to enable provider: ${errorMsg}`;
toast.error(fullErrorMsg);
return;
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: finalDefaultModelName,
}),
});
if (!setDefaultResponse.ok) {
const errorMsg = (await setDefaultResponse.json()).detail;
toast.error(`Failed to set provider as default: ${errorMsg}`);
return;
}
}
await refreshLlmProviderCaches(mutate);
onClose();
if (!hideSuccess) {
const successMsg = existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!";
toast.success(successMsg);
}
const knownProviders = new Set<string>(Object.values(LLMProviderName));
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
provider: knownProviders.has(providerName) ? providerName : "custom",
is_creation: !existingLlmProvider,
source: LLMProviderConfiguredSource.ADMIN_PAGE,
});
setSubmitting(false);
};
export const testApiKeyHelper = async (
providerName: string,
formValues: Record<string, unknown>,
@@ -79,7 +241,7 @@ export const testApiKeyHelper = async (
...((formValues?.custom_config as Record<string, unknown>) ?? {}),
...(customConfigOverride ?? {}),
},
model: modelName ?? (formValues?.test_model_name as string) ?? "",
model: modelName ?? (formValues?.default_model_name as string) ?? "",
};
return await submitLlmTestRequest(
@@ -97,148 +259,96 @@ export const testCustomProvider = async (
);
};
// ─── Submit provider ──────────────────────────────────────────────────────
export interface SubmitProviderParams<
T extends BaseLLMFormValues = BaseLLMFormValues,
> {
providerName: string;
values: T;
initialValues: T;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
isCustomProvider?: boolean;
setStatus: (status: Record<string, unknown>) => void;
setSubmitting: (submitting: boolean) => void;
onClose: () => void;
/** Called after successful create/update + set-default. Use for cache refresh, state updates, toasts, etc. */
onSuccess?: () => void | Promise<void>;
/** Analytics source for tracking. @default LLMProviderConfiguredSource.ADMIN_PAGE */
analyticsSource?: LLMProviderConfiguredSource;
}
export async function submitProvider<T extends BaseLLMFormValues>({
export const submitOnboardingProvider = async ({
providerName,
values,
initialValues,
existingLlmProvider,
shouldMarkAsDefault,
payload,
onboardingState,
onboardingActions,
isCustomProvider,
setStatus,
setSubmitting,
onClose,
onSuccess,
analyticsSource = LLMProviderConfiguredSource.ADMIN_PAGE,
}: SubmitProviderParams<T>): Promise<void> {
setSubmitting(true);
setIsSubmitting,
}: SubmitOnboardingProviderParams): Promise<void> => {
setIsSubmitting(true);
const { test_model_name, api_key, ...rest } = values;
const testModelName =
test_model_name ||
values.model_configurations.find((m) => m.is_visible)?.name ||
"";
// ── Test credentials ────────────────────────────────────────────────
const customConfigChanged = !isEqual(
values.custom_config,
initialValues.custom_config
);
const normalizedApiBase =
typeof rest.api_base === "string" && rest.api_base.trim() === ""
? undefined
: rest.api_base;
const finalValues = {
...rest,
api_base: normalizedApiBase,
api_key,
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
custom_config_changed: customConfigChanged,
};
if (!isEqual(finalValues, initialValues)) {
setStatus({ isTesting: true });
const testResult = await submitLlmTestRequest(
{
provider: providerName,
...finalValues,
model: testModelName,
id: existingLlmProvider?.id,
},
"An error occurred while testing the provider."
);
setStatus({ isTesting: false });
if (!testResult.ok) {
toast.error(testResult.errorMessage);
setSubmitting(false);
return;
}
// Test credentials
let result: TestApiKeyResult;
if (isCustomProvider) {
result = await testCustomProvider(payload);
} else {
result = await testApiKeyHelper(providerName, payload);
}
// ── Create/update provider ──────────────────────────────────────────
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}${
existingLlmProvider ? "" : "?is_creation=true"
}`,
{
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: providerName,
...finalValues,
id: existingLlmProvider?.id,
}),
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
const fullErrorMsg = existingLlmProvider
? `Failed to update provider: ${errorMsg}`
: `Failed to enable provider: ${errorMsg}`;
toast.error(fullErrorMsg);
setSubmitting(false);
if (!result.ok) {
toast.error(result.errorMessage);
setIsSubmitting(false);
return;
}
// ── Set as default ──────────────────────────────────────────────────
if (shouldMarkAsDefault && testModelName) {
// Create provider
const response = await fetch(`${LLM_PROVIDERS_ADMIN_URL}?is_creation=true`, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload),
});
if (!response.ok) {
const errorMsg = (await response.json()).detail;
toast.error(errorMsg);
setIsSubmitting(false);
return;
}
// Set as default if first provider
if (
onboardingState?.data?.llmProviders == null ||
onboardingState.data.llmProviders.length === 0
) {
try {
const newLlmProvider = await response.json();
if (newLlmProvider?.id != null) {
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: testModelName,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
toast.error(err?.detail ?? "Failed to set provider as default");
setSubmitting(false);
return;
const defaultModelName =
(payload as Record<string, string>).default_model_name ??
(payload as Record<string, ModelConfiguration[]>)
.model_configurations?.[0]?.name ??
"";
if (defaultModelName) {
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: defaultModelName,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
toast.error(err?.detail ?? "Failed to set provider as default");
setIsSubmitting(false);
return;
}
}
}
} catch {
} catch (_e) {
toast.error("Failed to set new provider as default");
}
}
// ── Post-success ────────────────────────────────────────────────────
const knownProviders = new Set<string>(Object.values(LLMProviderName));
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
provider: knownProviders.has(providerName) ? providerName : "custom",
is_creation: !existingLlmProvider,
source: analyticsSource,
provider: isCustomProvider ? "custom" : providerName,
is_creation: true,
source: LLMProviderConfiguredSource.CHAT_ONBOARDING,
});
if (onSuccess) await onSuccess();
// Update onboarding state
onboardingActions.updateData({
llmProviders: [
...(onboardingState?.data.llmProviders ?? []),
isCustomProvider ? "custom" : providerName,
],
});
onboardingActions.setButtonActive(true);
setSubmitting(false);
setIsSubmitting(false);
onClose();
}
};

View File

@@ -1,130 +1,197 @@
import {
LLMProviderName,
LLMProviderView,
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
import * as Yup from "yup";
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
import { ScopedMutator } from "swr";
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
// ─── useInitialValues ─────────────────────────────────────────────────────
// Common class names for the Form component across all LLM provider forms
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
/** Builds the merged model list from existing + well-known, deduped by name. */
function buildModelConfigurations(
export const buildDefaultInitialValues = (
existingLlmProvider?: LLMProviderView,
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
): ModelConfiguration[] {
const existingModels = existingLlmProvider?.model_configurations ?? [];
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
modelConfigurations?: ModelConfiguration[],
currentDefaultModelName?: string
) => {
const defaultModelName =
(currentDefaultModelName &&
existingLlmProvider?.model_configurations?.some(
(m) => m.name === currentDefaultModelName
)
? currentDefaultModelName
: undefined) ??
existingLlmProvider?.model_configurations?.[0]?.name ??
modelConfigurations?.[0]?.name ??
"";
const modelMap = new Map<string, ModelConfiguration>();
wellKnownModels.forEach((m) => modelMap.set(m.name, m));
existingModels.forEach((m) => modelMap.set(m.name, m));
return Array.from(modelMap.values());
}
/** Shared initial values for all LLM provider forms (both onboarding and admin). */
export function useInitialValues(
isOnboarding: boolean,
providerName: LLMProviderName,
existingLlmProvider?: LLMProviderView
) {
const { wellKnownLLMProvider } = useWellKnownLLMProvider(providerName);
const modelConfigurations = buildModelConfigurations(
existingLlmProvider,
wellKnownLLMProvider ?? undefined
);
const testModelName =
modelConfigurations.find((m) => m.is_visible)?.name ??
wellKnownLLMProvider?.recommended_default_model?.name;
// Auto mode must be explicitly enabled by the user
// Default to false for new providers, preserve existing value when editing
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;
return {
provider: existingLlmProvider?.provider ?? providerName,
name: isOnboarding ? providerName : existingLlmProvider?.name ?? "",
api_key: existingLlmProvider?.api_key ?? undefined,
api_base: existingLlmProvider?.api_base ?? undefined,
name: existingLlmProvider?.name || "",
default_model_name: defaultModelName,
is_public: existingLlmProvider?.is_public ?? true,
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
is_auto_mode: isAutoMode,
groups: existingLlmProvider?.groups ?? [],
personas: existingLlmProvider?.personas ?? [],
model_configurations: modelConfigurations,
test_model_name: testModelName,
selected_model_names: existingLlmProvider
? existingLlmProvider.model_configurations
.filter((modelConfiguration) => modelConfiguration.is_visible)
.map((modelConfiguration) => modelConfiguration.name)
: modelConfigurations
?.filter((modelConfiguration) => modelConfiguration.is_visible)
.map((modelConfiguration) => modelConfiguration.name) ?? [],
};
}
// ─── buildValidationSchema ────────────────────────────────────────────────
interface ValidationSchemaOptions {
apiKey?: boolean;
apiBase?: boolean;
extra?: Yup.ObjectShape;
}
/**
* Builds the validation schema for a modal.
*
* @param isOnboarding — controls the base schema:
* - `true`: minimal (only `test_model_name`).
* - `false`: full admin schema (display name, access, models, etc.).
* @param options.apiKey — require `api_key`.
* @param options.apiBase — require `api_base`.
* @param options.extra — arbitrary Yup fields for provider-specific validation.
*/
export function buildValidationSchema(
isOnboarding: boolean,
{ apiKey, apiBase, extra }: ValidationSchemaOptions = {}
) {
const providerFields: Yup.ObjectShape = {
...(apiKey && {
api_key: Yup.string().required("API Key is required"),
}),
...(apiBase && {
api_base: Yup.string().required("API Base URL is required"),
}),
...extra,
};
if (isOnboarding) {
return Yup.object().shape({
test_model_name: Yup.string().required("Model name is required"),
...providerFields,
});
}
};
export const buildDefaultValidationSchema = () => {
return Yup.object({
name: Yup.string().required("Display Name is required"),
default_model_name: Yup.string().required("Model name is required"),
is_public: Yup.boolean().required(),
is_auto_mode: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
personas: Yup.array().of(Yup.number()),
test_model_name: Yup.string().required("Model name is required"),
...providerFields,
selected_model_names: Yup.array().of(Yup.string()),
});
}
};
// ─── Form value types ─────────────────────────────────────────────────────
export const buildAvailableModelConfigurations = (
existingLlmProvider?: LLMProviderView,
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
): ModelConfiguration[] => {
const existingModels = existingLlmProvider?.model_configurations ?? [];
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
/** Base form values that all provider forms share. */
// Create a map to deduplicate by model name, preferring existing models
const modelMap = new Map<string, ModelConfiguration>();
// Add well-known models first
wellKnownModels.forEach((model) => {
modelMap.set(model.name, model);
});
// Override with existing models (they take precedence)
existingModels.forEach((model) => {
modelMap.set(model.name, model);
});
return Array.from(modelMap.values());
};
// Base form values that all provider forms share
export interface BaseLLMFormValues {
name: string;
api_key?: string;
api_base?: string;
/** Model name used for the test request — automatically derived. */
test_model_name?: string;
default_model_name?: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
/** The full model list with is_visible set directly by user interaction. */
model_configurations: ModelConfiguration[];
selected_model_names: string[];
custom_config?: Record<string, string>;
}
// ─── Misc ─────────────────────────────────────────────────────────────────
export interface SubmitLLMProviderParams<
T extends BaseLLMFormValues = BaseLLMFormValues,
> {
providerName: string;
values: T;
initialValues: T;
modelConfigurations: ModelConfiguration[];
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
hideSuccess?: boolean;
setIsTesting: (testing: boolean) => void;
mutate: ScopedMutator;
onClose: () => void;
setSubmitting: (submitting: boolean) => void;
}
export const filterModelConfigurations = (
currentModelConfigurations: ModelConfiguration[],
visibleModels: string[],
defaultModelName?: string
): ModelConfiguration[] => {
return currentModelConfigurations
.map(
(modelConfiguration): ModelConfiguration => ({
name: modelConfiguration.name,
is_visible: visibleModels.includes(modelConfiguration.name),
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
supports_image_input: modelConfiguration.supports_image_input,
supports_reasoning: modelConfiguration.supports_reasoning,
display_name: modelConfiguration.display_name,
})
)
.filter(
(modelConfiguration) =>
modelConfiguration.name === defaultModelName ||
modelConfiguration.is_visible
);
};
// Helper to get model configurations for auto mode
// In auto mode, we include ALL models but preserve their visibility status
// Models in the auto config are visible, others are created but not visible
export const getAutoModeModelConfigurations = (
modelConfigurations: ModelConfiguration[]
): ModelConfiguration[] => {
return modelConfigurations.map(
(modelConfiguration): ModelConfiguration => ({
name: modelConfiguration.name,
is_visible: modelConfiguration.is_visible,
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
supports_image_input: modelConfiguration.supports_image_input,
supports_reasoning: modelConfiguration.supports_reasoning,
display_name: modelConfiguration.display_name,
})
);
};
export type TestApiKeyResult =
| { ok: true }
| { ok: false; errorMessage: string };
export const getModelOptions = (
fetchedModelConfigurations: Array<{ name: string }>
) => {
return fetchedModelConfigurations.map((model) => ({
label: model.name,
value: model.name,
}));
};
/** Initial values used by onboarding forms (flat shape, always creating new). */
export const buildOnboardingInitialValues = () => ({
name: "",
provider: "",
api_key: "",
api_base: "",
api_version: "",
default_model_name: "",
model_configurations: [] as ModelConfiguration[],
custom_config: {} as Record<string, string>,
api_key_changed: true,
groups: [] as number[],
is_public: true,
is_auto_mode: false,
personas: [] as number[],
selected_model_names: [] as string[],
deployment_name: "",
target_uri: "",
});
export interface SubmitOnboardingProviderParams {
providerName: string;
payload: Record<string, unknown>;
onboardingState: OnboardingState;
onboardingActions: OnboardingActions;
isCustomProvider: boolean;
onClose: () => void;
setIsSubmitting: (submitting: boolean) => void;
}

Some files were not shown because too many files have changed in this diff Show More