mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-08 08:22:42 +00:00
Compare commits
3 Commits
fix/onboar
...
jamison/cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77ce667b21 | ||
|
|
1e0a8afc72 | ||
|
|
85302a1cde |
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -228,7 +228,7 @@ jobs:
|
||||
|
||||
- name: Create GitHub Release
|
||||
id: create-release
|
||||
uses: softprops/action-gh-release@153bb8e04406b158c6c84fc1615b65b24149a1fe # ratchet:softprops/action-gh-release@v2
|
||||
uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ steps.release-tag.outputs.tag }}
|
||||
name: ${{ steps.release-tag.outputs.tag }}
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Helm CLI
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.12.1
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
2
.github/workflows/pr-helm-chart-testing.yml
vendored
2
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -36,7 +36,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@dda3372f752e03dde6b3237bc9431cdc2f7a02a2 # ratchet:azure/setup-helm@v5.0.0
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -59,6 +59,3 @@ node_modules
|
||||
|
||||
# plans
|
||||
plans/
|
||||
|
||||
# Added context for LLMs
|
||||
onyx-llm-context/
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
@@ -19,6 +19,7 @@ from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import (
|
||||
MULTI_TENANT,
|
||||
@@ -44,6 +45,8 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -53,6 +56,25 @@ if USE_IAM_AUTH:
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def filter_tenants_by_range(
|
||||
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
|
||||
) -> list[str]:
|
||||
@@ -209,6 +231,7 @@ def do_run_migrations(
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -382,6 +405,7 @@ def run_migrations_offline() -> None:
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -423,6 +447,7 @@ def run_migrations_offline() -> None:
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -465,6 +490,7 @@ def run_migrations_online() -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
@@ -33,6 +35,27 @@ target_metadata = [PublicBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -62,6 +85,7 @@ def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -5,7 +5,6 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
@@ -31,7 +30,6 @@ def cloud_beat_task_generator(
|
||||
queue: str = OnyxCeleryTask.DEFAULT,
|
||||
priority: int = OnyxCeleryPriority.MEDIUM,
|
||||
expires: int = BEAT_EXPIRES_DEFAULT,
|
||||
skip_gated: bool = True,
|
||||
) -> bool | None:
|
||||
"""a lightweight task used to kick off individual beat tasks per tenant."""
|
||||
time_start = time.monotonic()
|
||||
@@ -50,22 +48,20 @@ def cloud_beat_task_generator(
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
num_skipped_gated = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
# Per-task control over whether gated tenants are included. Most periodic tasks
|
||||
# do no useful work on gated tenants and just waste DB connections fanning out
|
||||
# to ~10k+ inactive tenants. A small number of cleanup tasks (connector deletion,
|
||||
# checkpoint/index attempt cleanup) need to run on gated tenants and pass
|
||||
# `skip_gated=False` from the beat schedule.
|
||||
gated_tenants: set[str] = get_gated_tenants() if skip_gated else set()
|
||||
# NOTE: for now, we are running tasks for gated tenants, since we want to allow
|
||||
# connector deletion to run successfully. The new plan is to continously prune
|
||||
# the gated tenants set, so we won't have a build up of old, unused gated tenants.
|
||||
# Keeping this around in case we want to revert to the previous behavior.
|
||||
# gated_tenants = get_gated_tenants()
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated_tenants:
|
||||
num_skipped_gated += 1
|
||||
continue
|
||||
# Same comment here as the above NOTE
|
||||
# if tenant_id in gated_tenants:
|
||||
# continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
@@ -108,7 +104,6 @@ def cloud_beat_task_generator(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_skipped_gated={num_skipped_gated} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Overview of Onyx Background Jobs
|
||||
|
||||
The background jobs take care of:
|
||||
|
||||
1. Pulling/Indexing documents (from connectors)
|
||||
2. Updating document metadata (from connectors)
|
||||
3. Cleaning up checkpoints and logic around indexing work (indexing indexing checkpoints and index attempt metadata)
|
||||
@@ -10,41 +9,37 @@ The background jobs take care of:
|
||||
|
||||
## Worker → Queue Mapping
|
||||
|
||||
| Worker | File | Queues |
|
||||
| ------------------------- | ------------------------------ | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
| Worker | File | Queues |
|
||||
|--------|------|--------|
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
|
||||
## Non-Worker Apps
|
||||
|
||||
| App | File | Purpose |
|
||||
| ---------- | ----------- | ----------------------------------------------------------------------------------------------------- |
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
| App | File | Purpose |
|
||||
|-----|------|---------|
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
|
||||
### Shared Module
|
||||
|
||||
`app_base.py` provides:
|
||||
|
||||
- `TenantAwareTask` - Base task class that sets tenant context
|
||||
- Signal handlers for logging, cleanup, and lifecycle events
|
||||
- Readiness probes and health checks
|
||||
|
||||
|
||||
## Worker Details
|
||||
|
||||
### Primary (Coordinator and task dispatcher)
|
||||
|
||||
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
|
||||
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
|
||||
|
||||
On startup:
|
||||
|
||||
- waits for redis, postgres, document index to all be healthy
|
||||
- acquires the singleton lock
|
||||
- cleans all the redis states associated with background jobs
|
||||
@@ -52,34 +47,34 @@ On startup:
|
||||
|
||||
Then it cycles through its tasks as scheduled by Celery Beat:
|
||||
|
||||
| Task | Frequency | Description |
|
||||
| --------------------------------- | --------- | ------------------------------------------------------------------------------------------ |
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
| Task | Frequency | Description |
|
||||
|------|-----------|-------------|
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
|
||||
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
|
||||
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
|
||||
See supervisord.conf for watchdog config.
|
||||
|
||||
### Light
|
||||
|
||||
### Light
|
||||
Fast and short living tasks that are not resource intensive. High concurrency:
|
||||
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
|
||||
|
||||
Tasks it handles:
|
||||
|
||||
- Syncs access/permissions, document sets, boosts, hidden state
|
||||
- Deletes documents that are marked for deletion in Postgres
|
||||
- Cleanup of checkpoints and index attempts
|
||||
|
||||
### Heavy
|
||||
|
||||
### Heavy
|
||||
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
|
||||
|
||||
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
|
||||
@@ -88,24 +83,16 @@ Generates CSV exports which may take a long time with significant data in Postgr
|
||||
|
||||
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
|
||||
|
||||
|
||||
### Docprocessing, Docfetching, User File Processing
|
||||
|
||||
Docprocessing and Docfetching are for indexing documents:
|
||||
|
||||
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
- User Files come from uploads directly via the input bar
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
User Files come from uploads directly via the input bar
|
||||
|
||||
|
||||
### Monitoring
|
||||
|
||||
Observability and metrics collections:
|
||||
|
||||
- Queue lengths, connector success/failure, connector latencies
|
||||
- Queue lengths, connector success/failure, lconnector latencies
|
||||
- Memory of supervisor managed processes (workers, beat, slack)
|
||||
- Cloud and multitenant specific monitorings
|
||||
|
||||
## Prometheus Metrics
|
||||
|
||||
Workers can expose Prometheus metrics via a standalone HTTP server. Currently docfetching and docprocessing have push-based task lifecycle metrics; the monitoring worker runs pull-based collectors for queue depth and connector health.
|
||||
|
||||
For the full metric reference, integration guide, and PromQL examples, see [`docs/METRICS.md`](../../../docs/METRICS.md#celery-worker-metrics).
|
||||
|
||||
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -75,8 +75,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Run on gated tenants too — they may still have stale checkpoints to clean.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -86,8 +84,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Run on gated tenants too — they may still have stale index attempts.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -97,8 +93,6 @@ beat_task_templates: list[dict] = [
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
# Gated tenants may still have connectors awaiting deletion.
|
||||
"skip_gated": False,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -272,7 +266,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires", "skip_gated"]
|
||||
optional_fields = ["queue", "priority", "expires"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
@@ -365,13 +359,7 @@ if not MULTI_TENANT:
|
||||
]
|
||||
)
|
||||
|
||||
# `skip_gated` is a cloud-only hint consumed by `cloud_beat_task_generator`. Strip
|
||||
# it before extending the self-hosted schedule so it doesn't leak into apply_async
|
||||
# as an unrecognised option on every fired task message.
|
||||
for _template in beat_task_templates:
|
||||
_self_hosted_template = copy.deepcopy(_template)
|
||||
_self_hosted_template["options"].pop("skip_gated", None)
|
||||
tasks_to_schedule.append(_self_hosted_template)
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
|
||||
138
backend/onyx/background/celery/tasks/periodic/tasks.py
Normal file
138
backend/onyx/background/celery/tasks/periodic/tasks.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PostgresAdvisoryLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
@@ -379,14 +379,6 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
# Comma-separated replica / multi-host list. If unset, defaults to POSTGRES_HOST
|
||||
# only.
|
||||
_POSTGRES_HOSTS_STR = os.environ.get("POSTGRES_HOSTS", "").strip()
|
||||
POSTGRES_HOSTS: list[str] = (
|
||||
[h.strip() for h in _POSTGRES_HOSTS_STR.split(",") if h.strip()]
|
||||
if _POSTGRES_HOSTS_STR
|
||||
else [POSTGRES_HOST]
|
||||
)
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
|
||||
@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -396,6 +391,10 @@ class MilestoneRecordType(str, Enum):
|
||||
REQUESTED_CONNECTOR = "requested_connector"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class OnyxCeleryQueues:
|
||||
# "celery" is the default queue defined by celery and also the queue
|
||||
# we are running in the primary worker to run system tasks
|
||||
@@ -578,6 +577,7 @@ class OnyxCeleryTask:
|
||||
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
|
||||
CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
|
||||
@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
|
||||
return fetch_all_federated_connectors(db_session)
|
||||
|
||||
|
||||
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
|
||||
"""Raise if any credential string value contains mask placeholder characters.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
Both must be rejected.
|
||||
"""
|
||||
for key, val in credentials.items():
|
||||
if isinstance(val, str) and (
|
||||
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
|
||||
)
|
||||
|
||||
|
||||
def validate_federated_connector_credentials(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -85,8 +66,6 @@ def create_federated_connector(
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> FederatedConnector:
|
||||
"""Create a new federated connector with credential and config validation."""
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before creating
|
||||
if not validate_federated_connector_credentials(source, credentials):
|
||||
raise ValueError(
|
||||
@@ -298,8 +277,6 @@ def update_federated_connector(
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before updating
|
||||
if not validate_federated_connector_credentials(
|
||||
federated_connector.source, credentials
|
||||
|
||||
@@ -236,15 +236,14 @@ def upsert_llm_provider(
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
# Filter out empty strings and None values from custom_config to allow
|
||||
# providers like Bedrock to fall back to IAM roles when credentials are not provided.
|
||||
# NOTE: An empty dict ({}) is preserved as-is — it signals that the provider was
|
||||
# created via the custom modal and must be reopened with CustomModal, not a
|
||||
# provider-specific modal. Only None means "no custom config at all".
|
||||
# providers like Bedrock to fall back to IAM roles when credentials are not provided
|
||||
custom_config = llm_provider_upsert_request.custom_config
|
||||
if custom_config:
|
||||
custom_config = {
|
||||
k: v for k, v in custom_config.items() if v is not None and v.strip() != ""
|
||||
}
|
||||
# Set to None if the dict is empty after filtering
|
||||
custom_config = custom_config or None
|
||||
|
||||
api_base = llm_provider_upsert_request.api_base or None
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
@@ -304,7 +303,16 @@ def upsert_llm_provider(
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.flush()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
|
||||
for model_config in llm_provider_upsert_request.model_configurations:
|
||||
max_input_tokens = model_config.max_input_tokens
|
||||
if max_input_tokens is None:
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=model_config.name,
|
||||
model_provider=llm_provider_upsert_request.provider,
|
||||
)
|
||||
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model_config.supports_image_input:
|
||||
@@ -317,7 +325,7 @@ def upsert_llm_provider(
|
||||
model_configuration_id=existing.id,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=model_config.is_visible,
|
||||
max_input_tokens=model_config.max_input_tokens,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
else:
|
||||
@@ -327,7 +335,7 @@ def upsert_llm_provider(
|
||||
model_name=model_config.name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=model_config.is_visible,
|
||||
max_input_tokens=model_config.max_input_tokens,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -52,21 +52,9 @@ KNOWN_OPENPYXL_BUGS = [
|
||||
|
||||
def get_markitdown_converter() -> "MarkItDown":
|
||||
global _MARKITDOWN_CONVERTER
|
||||
from markitdown import MarkItDown
|
||||
|
||||
if _MARKITDOWN_CONVERTER is None:
|
||||
from markitdown import MarkItDown
|
||||
|
||||
# Patch this function to effectively no-op because we were seeing this
|
||||
# module take an inordinate amount of time to convert charts to markdown,
|
||||
# making some powerpoint files with many or complicated charts nearly
|
||||
# unindexable.
|
||||
from markitdown.converters._pptx_converter import PptxConverter
|
||||
|
||||
setattr(
|
||||
PptxConverter,
|
||||
"_convert_chart_to_markdown",
|
||||
lambda self, chart: "\n\n[chart omitted]\n\n", # noqa: ARG005
|
||||
)
|
||||
_MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False)
|
||||
return _MARKITDOWN_CONVERTER
|
||||
|
||||
@@ -217,26 +205,18 @@ def read_pdf_file(
|
||||
try:
|
||||
pdf_reader = PdfReader(file)
|
||||
|
||||
if pdf_reader.is_encrypted:
|
||||
# Try the explicit password first, then fall back to an empty
|
||||
# string. Owner-password-only PDFs (permission restrictions but
|
||||
# no open password) decrypt successfully with "".
|
||||
# See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
passwords = [p for p in [pdf_pass, ""] if p is not None]
|
||||
if pdf_reader.is_encrypted and pdf_pass is not None:
|
||||
decrypt_success = False
|
||||
for pw in passwords:
|
||||
try:
|
||||
if pdf_reader.decrypt(pw) != 0:
|
||||
decrypt_success = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
|
||||
except Exception:
|
||||
logger.error("Unable to decrypt pdf")
|
||||
|
||||
if not decrypt_success:
|
||||
logger.error(
|
||||
"Encrypted PDF could not be decrypted, returning empty text."
|
||||
)
|
||||
return "", metadata, []
|
||||
elif pdf_reader.is_encrypted:
|
||||
logger.warning("No Password for an encrypted PDF, returning empty text.")
|
||||
return "", metadata, []
|
||||
|
||||
# Basic PDF metadata
|
||||
if pdf_reader.metadata is not None:
|
||||
|
||||
@@ -33,20 +33,8 @@ def is_pdf_protected(file: IO[Any]) -> bool:
|
||||
|
||||
with preserve_position(file):
|
||||
reader = PdfReader(file)
|
||||
if not reader.is_encrypted:
|
||||
return False
|
||||
|
||||
# PDFs with only an owner password (permission restrictions like
|
||||
# print/copy disabled) use an empty user password — any viewer can open
|
||||
# them without prompting. decrypt("") returns 0 only when a real user
|
||||
# password is required. See https://github.com/onyx-dot-app/onyx/issues/9754
|
||||
try:
|
||||
return reader.decrypt("") == 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to evaluate PDF encryption; treating as password protected"
|
||||
)
|
||||
return True
|
||||
return bool(reader.is_encrypted)
|
||||
|
||||
|
||||
def is_docx_protected(file: IO[Any]) -> bool:
|
||||
|
||||
@@ -26,7 +26,6 @@ class LlmProviderNames(str, Enum):
|
||||
MISTRAL = "mistral"
|
||||
LITELLM_PROXY = "litellm_proxy"
|
||||
BIFROST = "bifrost"
|
||||
OPENAI_COMPATIBLE = "openai_compatible"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Needed so things like:
|
||||
@@ -47,7 +46,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
]
|
||||
|
||||
|
||||
@@ -66,7 +64,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
LlmProviderNames.BIFROST: "Bifrost",
|
||||
LlmProviderNames.OPENAI_COMPATIBLE: "OpenAI Compatible",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -119,7 +116,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -327,19 +327,12 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
model_kwargs[VERTEX_LOCATION_KWARG] = "global"
|
||||
|
||||
# Bifrost and OpenAI-compatible: OpenAI-compatible proxies that send
|
||||
# model names directly to the endpoint. We route through LiteLLM's
|
||||
# openai provider with the server's base URL, and ensure /v1 is appended.
|
||||
if model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
):
|
||||
# Bifrost: OpenAI-compatible proxy that expects model names in
|
||||
# provider/model format (e.g. "anthropic/claude-sonnet-4-6").
|
||||
# We route through LiteLLM's openai provider with the Bifrost base URL,
|
||||
# and ensure /v1 is appended.
|
||||
if model_provider == LlmProviderNames.BIFROST:
|
||||
self._custom_llm_provider = "openai"
|
||||
# LiteLLM's OpenAI client requires an api_key to be set.
|
||||
# Many OpenAI-compatible servers don't need auth, so supply a
|
||||
# placeholder to prevent LiteLLM from raising AuthenticationError.
|
||||
if not self._api_key:
|
||||
model_kwargs.setdefault("api_key", "not-needed")
|
||||
if self._api_base is not None:
|
||||
base = self._api_base.rstrip("/")
|
||||
self._api_base = base if base.endswith("/v1") else f"{base}/v1"
|
||||
@@ -456,20 +449,17 @@ class LitellmLLM(LLM):
|
||||
optional_kwargs: dict[str, Any] = {}
|
||||
|
||||
# Model name
|
||||
is_openai_compatible_proxy = self._model_provider in (
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
)
|
||||
is_bifrost = self._model_provider == LlmProviderNames.BIFROST
|
||||
model_provider = (
|
||||
f"{self.config.model_provider}/responses"
|
||||
if is_openai_model # Uses litellm's completions -> responses bridge
|
||||
else self.config.model_provider
|
||||
)
|
||||
if is_openai_compatible_proxy:
|
||||
# OpenAI-compatible proxies (Bifrost, generic OpenAI-compatible
|
||||
# servers) expect model names sent directly to their endpoint.
|
||||
# We use custom_llm_provider="openai" so LiteLLM doesn't try
|
||||
# to route based on the provider prefix.
|
||||
if is_bifrost:
|
||||
# Bifrost expects model names in provider/model format
|
||||
# (e.g. "anthropic/claude-sonnet-4-6") sent directly to its
|
||||
# OpenAI-compatible endpoint. We use custom_llm_provider="openai"
|
||||
# so LiteLLM doesn't try to route based on the provider prefix.
|
||||
model = self.config.deployment_name or self.config.model_name
|
||||
else:
|
||||
model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}"
|
||||
@@ -560,10 +550,7 @@ class LitellmLLM(LLM):
|
||||
if structured_response_format:
|
||||
optional_kwargs["response_format"] = structured_response_format
|
||||
|
||||
if (
|
||||
not (is_claude_model or is_ollama or is_mistral)
|
||||
or is_openai_compatible_proxy
|
||||
):
|
||||
if not (is_claude_model or is_ollama or is_mistral) or is_bifrost:
|
||||
# Litellm bug: tool_choice is dropped silently if not specified here for OpenAI
|
||||
# However, this param breaks Anthropic and Mistral models,
|
||||
# so it must be conditionally included unless the request is
|
||||
|
||||
@@ -15,8 +15,6 @@ LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
BIFROST_PROVIDER_NAME = "bifrost"
|
||||
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME = "openai_compatible"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_COMPATIBLE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
@@ -52,7 +51,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: [], # Dynamic - fetched from OpenAI-compatible API
|
||||
}
|
||||
|
||||
|
||||
@@ -338,7 +336,6 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
OPENAI_COMPATIBLE_PROVIDER_NAME: "OpenAI Compatible",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -6,7 +6,6 @@ from onyx.configs.app_configs import MCP_SERVER_ENABLED
|
||||
from onyx.configs.app_configs import MCP_SERVER_HOST
|
||||
from onyx.configs.app_configs import MCP_SERVER_PORT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -17,7 +16,6 @@ def main() -> None:
|
||||
logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)")
|
||||
return
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}")
|
||||
|
||||
from onyx.mcp_server.api import mcp_app
|
||||
|
||||
@@ -74,8 +74,6 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenAICompatibleModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
@@ -1577,95 +1575,3 @@ def _get_bifrost_models_response(api_base: str, api_key: str | None = None) -> d
|
||||
source_name="Bifrost",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/openai-compatible/available-models")
|
||||
def get_openai_compatible_server_available_models(
|
||||
request: OpenAICompatibleModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OpenAICompatibleFinalModelResponse]:
|
||||
"""Fetch available models from a generic OpenAI-compatible /v1/models endpoint."""
|
||||
response_json = _get_openai_compatible_server_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenAICompatibleFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_id = model.get("id", "")
|
||||
model_name = model.get("name", model_id)
|
||||
|
||||
if not model_id:
|
||||
continue
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
OpenAICompatibleFinalModelResponse(
|
||||
name=model_id,
|
||||
display_name=model_name,
|
||||
max_input_tokens=model.get("context_length"),
|
||||
supports_image_input=infer_vision_support(model_id),
|
||||
supports_reasoning=is_reasoning_model(model_id, model_name),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse OpenAI-compatible model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from OpenAI-compatible endpoint",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenAI Compatible",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openai_compatible_server_response(
|
||||
api_base: str, api_key: str | None = None
|
||||
) -> dict:
|
||||
"""Perform GET to an OpenAI-compatible /v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
# Ensure we hit /v1/models
|
||||
if cleaned_api_base.endswith("/v1"):
|
||||
url = f"{cleaned_api_base}/models"
|
||||
else:
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
return _get_openai_compatible_models_response(
|
||||
url=url,
|
||||
source_name="OpenAI Compatible",
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
@@ -79,9 +79,7 @@ class LLMProviderDescriptor(BaseModel):
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations,
|
||||
provider,
|
||||
use_stored_display_name=llm_provider_model.custom_config is not None,
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
)
|
||||
|
||||
@@ -158,9 +156,7 @@ class LLMProviderView(LLMProvider):
|
||||
personas=personas,
|
||||
deployment_name=llm_provider_model.deployment_name,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations,
|
||||
provider,
|
||||
use_stored_display_name=llm_provider_model.custom_config is not None,
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
)
|
||||
|
||||
@@ -202,13 +198,13 @@ class ModelConfigurationView(BaseModel):
|
||||
cls,
|
||||
model_configuration_model: "ModelConfigurationModel",
|
||||
provider_name: str,
|
||||
use_stored_display_name: bool = False,
|
||||
) -> "ModelConfigurationView":
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama) and custom-config
|
||||
# providers, use the display_name stored in DB. Skip LiteLLM parsing.
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), use the display_name
|
||||
# stored in DB from the source API. Skip LiteLLM parsing entirely.
|
||||
if (
|
||||
provider_name in DYNAMIC_LLM_PROVIDERS or use_stored_display_name
|
||||
) and model_configuration_model.display_name:
|
||||
provider_name in DYNAMIC_LLM_PROVIDERS
|
||||
and model_configuration_model.display_name
|
||||
):
|
||||
# Extract vendor from model name for grouping (e.g., "Anthropic", "OpenAI")
|
||||
vendor = extract_vendor_from_model_name(
|
||||
model_configuration_model.name, provider_name
|
||||
@@ -468,18 +464,3 @@ class BifrostFinalModelResponse(BaseModel):
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
|
||||
# OpenAI Compatible dynamic models fetch
|
||||
class OpenAICompatibleModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str | None = None
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class OpenAICompatibleFinalModelResponse(BaseModel):
|
||||
name: str # Model ID (e.g. "meta-llama/Llama-3-8B-Instruct")
|
||||
display_name: str # Human-readable name from API
|
||||
max_input_tokens: int | None
|
||||
supports_image_input: bool
|
||||
supports_reasoning: bool
|
||||
|
||||
@@ -26,7 +26,6 @@ DYNAMIC_LLM_PROVIDERS = frozenset(
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.BIFROST,
|
||||
LlmProviderNames.OPENAI_COMPATIBLE,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -309,15 +308,12 @@ def should_filter_as_dated_duplicate(
|
||||
def filter_model_configurations(
|
||||
model_configurations: list,
|
||||
provider: str,
|
||||
use_stored_display_name: bool = False,
|
||||
) -> list:
|
||||
"""Filter out obsolete and dated duplicate models from configurations.
|
||||
|
||||
Args:
|
||||
model_configurations: List of ModelConfiguration DB models
|
||||
provider: The provider name (e.g., "openai", "anthropic")
|
||||
use_stored_display_name: If True, prefer the display_name stored in the
|
||||
DB over LiteLLM enrichments. Set for custom-config providers.
|
||||
|
||||
Returns:
|
||||
List of ModelConfigurationView objects with obsolete/duplicate models removed
|
||||
@@ -337,9 +333,7 @@ def filter_model_configurations(
|
||||
if should_filter_as_dated_duplicate(model_configuration.name, all_model_names):
|
||||
continue
|
||||
filtered_configs.append(
|
||||
ModelConfigurationView.from_model(
|
||||
model_configuration, provider, use_stored_display_name
|
||||
)
|
||||
ModelConfigurationView.from_model(model_configuration, provider)
|
||||
)
|
||||
|
||||
return filtered_configs
|
||||
|
||||
@@ -186,7 +186,7 @@ class TestDocumentIndexNew:
|
||||
)
|
||||
document_index.index(chunks=[pre_chunk], indexing_metadata=pre_metadata)
|
||||
|
||||
time.sleep(2)
|
||||
time.sleep(1)
|
||||
|
||||
# Now index a batch with the existing doc and a new doc.
|
||||
chunks = [
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.db.federated import _reject_masked_credentials
|
||||
|
||||
|
||||
class TestRejectMaskedCredentials:
|
||||
"""Verify that masked credential values are never accepted for DB writes.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
_reject_masked_credentials must catch both.
|
||||
"""
|
||||
|
||||
def test_rejects_fully_masked_value(self) -> None:
|
||||
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": masked})
|
||||
|
||||
def test_rejects_long_string_masked_value(self) -> None:
|
||||
"""mask_string returns 'first4...last4' for long strings — the real
|
||||
format used for OAuth credentials like client_id and client_secret."""
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": "1234...7890"})
|
||||
|
||||
def test_rejects_when_any_field_is_masked(self) -> None:
|
||||
"""Even if client_id is real, a masked client_secret must be caught."""
|
||||
with pytest.raises(ValueError, match="client_secret"):
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": MASK_CREDENTIAL_CHAR * 12,
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_real_credentials(self) -> None:
|
||||
# Should not raise
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": "test_client_secret_value",
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_empty_dict(self) -> None:
|
||||
# Should not raise — empty credentials are handled elsewhere
|
||||
_reject_masked_credentials({})
|
||||
|
||||
def test_ignores_non_string_values(self) -> None:
|
||||
# Non-string values (None, bool, int) should pass through
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "real_value",
|
||||
"redirect_uri": None,
|
||||
"some_flag": True,
|
||||
}
|
||||
)
|
||||
@@ -1,76 +0,0 @@
|
||||
%PDF-1.3
|
||||
%<25><><EFBFBD><EFBFBD>
|
||||
1 0 obj
|
||||
<<
|
||||
/Producer <1083d595b1>
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Count 1
|
||||
/Kids [ 4 0 R ]
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
/MediaBox [ 0.0 0.0 200 200 ]
|
||||
/Contents 5 0 R
|
||||
/Parent 2 0 R
|
||||
>>
|
||||
endobj
|
||||
5 0 obj
|
||||
<<
|
||||
/Length 42
|
||||
>>
|
||||
stream
|
||||
,N<><6~<7E>)<29><><EFBFBD><EFBFBD><EFBFBD>u<EFBFBD><0C><><EFBFBD>Zc'<27><>>8g<38><67><EFBFBD>n<EFBFBD><6E><EFBFBD><EFBFBD><EFBFBD>9"
|
||||
endstream
|
||||
endobj
|
||||
6 0 obj
|
||||
<<
|
||||
/V 2
|
||||
/R 3
|
||||
/Length 128
|
||||
/P 4294967292
|
||||
/Filter /Standard
|
||||
/O <6a340a292629053da84a6d8b19a5d505953b8b3fdac3d2d389fde0e354528d44>
|
||||
/U <d6f0dc91c7b9de264a8d708515468e6528bf4e5e4e758a4164004e56fffa0108>
|
||||
>>
|
||||
endobj
|
||||
xref
|
||||
0 7
|
||||
0000000000 65535 f
|
||||
0000000015 00000 n
|
||||
0000000059 00000 n
|
||||
0000000118 00000 n
|
||||
0000000167 00000 n
|
||||
0000000348 00000 n
|
||||
0000000440 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 7
|
||||
/Root 3 0 R
|
||||
/Info 1 0 R
|
||||
/ID [ <6364336635356135633239323638353039306635656133623165313637366430> <6364336635356135633239323638353039306635656133623165313637366430> ]
|
||||
/Encrypt 6 0 R
|
||||
>>
|
||||
startxref
|
||||
655
|
||||
%%EOF
|
||||
@@ -54,12 +54,6 @@ class TestReadPdfFile:
|
||||
text, _, _ = read_pdf_file(_load("encrypted.pdf"), pdf_pass="wrong")
|
||||
assert text == ""
|
||||
|
||||
def test_owner_password_only_pdf_extracts_text(self) -> None:
|
||||
"""A PDF encrypted with only an owner password (no user password)
|
||||
should still yield its text content. Regression for #9754."""
|
||||
text, _, _ = read_pdf_file(_load("owner_protected.pdf"))
|
||||
assert "Hello World" in text
|
||||
|
||||
def test_empty_pdf(self) -> None:
|
||||
text, _, _ = read_pdf_file(_load("empty.pdf"))
|
||||
assert text.strip() == ""
|
||||
@@ -123,12 +117,6 @@ class TestIsPdfProtected:
|
||||
def test_protected_pdf(self) -> None:
|
||||
assert is_pdf_protected(_load("encrypted.pdf")) is True
|
||||
|
||||
def test_owner_password_only_is_not_protected(self) -> None:
|
||||
"""A PDF with only an owner password (permission restrictions) but no
|
||||
user password should NOT be considered protected — any viewer can open
|
||||
it without prompting for a password."""
|
||||
assert is_pdf_protected(_load("owner_protected.pdf")) is False
|
||||
|
||||
def test_preserves_file_position(self) -> None:
|
||||
pdf = _load("simple.pdf")
|
||||
pdf.seek(42)
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import io
|
||||
|
||||
from pptx import Presentation # type: ignore[import-untyped]
|
||||
from pptx.chart.data import CategoryChartData # type: ignore[import-untyped]
|
||||
from pptx.enum.chart import XL_CHART_TYPE # type: ignore[import-untyped]
|
||||
from pptx.util import Inches # type: ignore[import-untyped]
|
||||
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
|
||||
|
||||
def _make_pptx_with_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with one text slide and one chart slide."""
|
||||
prs = Presentation()
|
||||
|
||||
# Slide 1: text only
|
||||
slide1 = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide1.shapes.title.text = "Introduction"
|
||||
slide1.placeholders[1].text = "This is the first slide."
|
||||
|
||||
# Slide 2: chart
|
||||
slide2 = prs.slides.add_slide(prs.slide_layouts[5]) # Blank layout
|
||||
chart_data = CategoryChartData()
|
||||
chart_data.categories = ["Q1", "Q2", "Q3"]
|
||||
chart_data.add_series("Revenue", (100, 200, 300))
|
||||
slide2.shapes.add_chart(
|
||||
XL_CHART_TYPE.COLUMN_CLUSTERED,
|
||||
Inches(1),
|
||||
Inches(1),
|
||||
Inches(6),
|
||||
Inches(4),
|
||||
chart_data,
|
||||
)
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
def _make_pptx_without_chart() -> io.BytesIO:
|
||||
"""Create an in-memory pptx with a single text-only slide."""
|
||||
prs = Presentation()
|
||||
slide = prs.slides.add_slide(prs.slide_layouts[1])
|
||||
slide.shapes.title.text = "Hello World"
|
||||
slide.placeholders[1].text = "Some content here."
|
||||
|
||||
buf = io.BytesIO()
|
||||
prs.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestPptxToText:
|
||||
def test_chart_is_omitted(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_with_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Introduction" in result
|
||||
assert "first slide" in result
|
||||
assert "[chart omitted]" in result
|
||||
# The actual chart data should NOT appear in the output.
|
||||
assert "Revenue" not in result
|
||||
assert "Q1" not in result
|
||||
|
||||
def test_text_only_pptx(self) -> None:
|
||||
# Precondition
|
||||
pptx_file = _make_pptx_without_chart()
|
||||
|
||||
# Under test
|
||||
result = pptx_to_text(pptx_file)
|
||||
|
||||
# Postcondition
|
||||
assert "Hello World" in result
|
||||
assert "Some content" in result
|
||||
assert "[chart omitted]" not in result
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -25,7 +24,7 @@ Use --json for machine-readable output.`,
|
||||
onyx-cli agents --json
|
||||
onyx-cli agents --json | jq '.[].name'`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
if !cfg.IsConfigured() {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/overflow"
|
||||
@@ -49,7 +48,7 @@ to a temp file. Set --max-output 0 to disable truncation.`,
|
||||
cat error.log | onyx-cli ask --prompt "Find the root cause"
|
||||
echo "what is onyx?" | onyx-cli ask`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
if !cfg.IsConfigured() {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ an interactive setup wizard will guide you through configuration.`,
|
||||
Example: ` onyx-cli chat
|
||||
onyx-cli`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
|
||||
// First-run: onboarding
|
||||
if !config.ConfigExists() || !cfg.IsConfigured() {
|
||||
|
||||
@@ -69,7 +69,7 @@ Use --dry-run to test the connection without saving the configuration.`,
|
||||
return exitcodes.New(exitcodes.BadRequest, "both --server-url and --api-key are required for non-interactive setup\n Run 'onyx-cli configure' without flags for interactive setup")
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
onboarding.Run(&cfg)
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -12,7 +12,7 @@ func newExperimentsCmd() *cobra.Command {
|
||||
Use: "experiments",
|
||||
Short: "List experimental features and their status",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), config.ExperimentsText(cfg.Features))
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -13,6 +13,20 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// loadConfig loads the CLI config, using the --config-file persistent flag if set.
|
||||
func loadConfig(cmd *cobra.Command) config.OnyxCliConfig {
|
||||
cf, _ := cmd.Flags().GetString("config-file")
|
||||
return config.Load(cf)
|
||||
}
|
||||
|
||||
// effectiveConfigPath returns the config file path, respecting --config-file.
|
||||
func effectiveConfigPath(cmd *cobra.Command) string {
|
||||
if cf, _ := cmd.Flags().GetString("config-file"); cf != "" {
|
||||
return cf
|
||||
}
|
||||
return config.ConfigFilePath()
|
||||
}
|
||||
|
||||
// Version and Commit are set via ldflags at build time.
|
||||
var (
|
||||
Version string
|
||||
@@ -29,7 +43,7 @@ func fullVersion() string {
|
||||
func printVersion(cmd *cobra.Command) {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Client version: %s\n", fullVersion())
|
||||
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
if !cfg.IsConfigured() {
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server version: unknown (not configured)\n")
|
||||
return
|
||||
@@ -84,6 +98,8 @@ func Execute() error {
|
||||
}
|
||||
|
||||
rootCmd.PersistentFlags().BoolVar(&opts.Debug, "debug", false, "run in debug mode")
|
||||
rootCmd.PersistentFlags().String("config-file", "",
|
||||
"Path to config file (default: "+config.ConfigFilePath()+")")
|
||||
|
||||
// Custom --version flag instead of Cobra's built-in (which only shows one version string)
|
||||
var showVersion bool
|
||||
|
||||
@@ -50,16 +50,14 @@ func sessionEnv(s ssh.Session, key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func validateAPIKey(serverURL string, apiKey string) error {
|
||||
func validateAPIKey(serverCfg config.OnyxCliConfig, apiKey string) error {
|
||||
trimmedKey := strings.TrimSpace(apiKey)
|
||||
if len(trimmedKey) > maxAPIKeyLength {
|
||||
return fmt.Errorf("API key is too long (max %d characters)", maxAPIKeyLength)
|
||||
}
|
||||
|
||||
cfg := config.OnyxCliConfig{
|
||||
ServerURL: serverURL,
|
||||
APIKey: trimmedKey,
|
||||
}
|
||||
cfg := serverCfg
|
||||
cfg.APIKey = trimmedKey
|
||||
client := api.NewClient(cfg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), apiKeyValidationTimeout)
|
||||
defer cancel()
|
||||
@@ -83,7 +81,7 @@ type authValidatedMsg struct {
|
||||
|
||||
type authModel struct {
|
||||
input textinput.Model
|
||||
serverURL string
|
||||
serverCfg config.OnyxCliConfig
|
||||
state authState
|
||||
apiKey string // set on successful validation
|
||||
errMsg string
|
||||
@@ -91,7 +89,7 @@ type authModel struct {
|
||||
aborted bool
|
||||
}
|
||||
|
||||
func newAuthModel(serverURL, initialErr string) authModel {
|
||||
func newAuthModel(serverCfg config.OnyxCliConfig, initialErr string) authModel {
|
||||
ti := textinput.New()
|
||||
ti.Prompt = " API Key: "
|
||||
ti.EchoMode = textinput.EchoPassword
|
||||
@@ -102,7 +100,7 @@ func newAuthModel(serverURL, initialErr string) authModel {
|
||||
|
||||
return authModel{
|
||||
input: ti,
|
||||
serverURL: serverURL,
|
||||
serverCfg: serverCfg,
|
||||
errMsg: initialErr,
|
||||
}
|
||||
}
|
||||
@@ -138,9 +136,9 @@ func (m authModel) Update(msg tea.Msg) (authModel, tea.Cmd) {
|
||||
}
|
||||
m.state = authValidating
|
||||
m.errMsg = ""
|
||||
serverURL := m.serverURL
|
||||
serverCfg := m.serverCfg
|
||||
return m, func() tea.Msg {
|
||||
return authValidatedMsg{key: key, err: validateAPIKey(serverURL, key)}
|
||||
return authValidatedMsg{key: key, err: validateAPIKey(serverCfg, key)}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,12 +169,13 @@ func (m authModel) Update(msg tea.Msg) (authModel, tea.Cmd) {
|
||||
}
|
||||
|
||||
func (m authModel) View() string {
|
||||
settingsURL := strings.TrimRight(m.serverURL, "/") + "/app/settings/accounts-access"
|
||||
serverURL := m.serverCfg.ServerURL
|
||||
settingsURL := strings.TrimRight(serverURL, "/") + "/app/settings/accounts-access"
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
b.WriteString(" \x1b[1;35mOnyx CLI\x1b[0m\n")
|
||||
b.WriteString(" \x1b[90m" + m.serverURL + "\x1b[0m\n")
|
||||
b.WriteString(" \x1b[90m" + serverURL + "\x1b[0m\n")
|
||||
b.WriteString("\n")
|
||||
b.WriteString(" Generate an API key at:\n")
|
||||
b.WriteString(" \x1b[4;34m" + settingsURL + "\x1b[0m\n")
|
||||
@@ -215,7 +214,7 @@ type serveModel struct {
|
||||
|
||||
func newServeModel(serverCfg config.OnyxCliConfig, initialErr string) serveModel {
|
||||
return serveModel{
|
||||
auth: newAuthModel(serverCfg.ServerURL, initialErr),
|
||||
auth: newAuthModel(serverCfg, initialErr),
|
||||
serverCfg: serverCfg,
|
||||
}
|
||||
}
|
||||
@@ -238,11 +237,8 @@ func (m serveModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Quit
|
||||
}
|
||||
if m.auth.apiKey != "" {
|
||||
cfg := config.OnyxCliConfig{
|
||||
ServerURL: m.serverCfg.ServerURL,
|
||||
APIKey: m.auth.apiKey,
|
||||
DefaultAgentID: m.serverCfg.DefaultAgentID,
|
||||
}
|
||||
cfg := m.serverCfg
|
||||
cfg.APIKey = m.auth.apiKey
|
||||
m.tui = tui.NewModel(cfg)
|
||||
m.authed = true
|
||||
w, h := m.width, m.height
|
||||
@@ -280,6 +276,8 @@ func newServeCmd() *cobra.Command {
|
||||
rateLimitPerMin int
|
||||
rateLimitBurst int
|
||||
rateLimitCache int
|
||||
serverURL string
|
||||
apiServerURL string
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
@@ -300,9 +298,18 @@ environment variable (the --host-key flag takes precedence).`,
|
||||
Example: ` onyx-cli serve --port 2222
|
||||
ssh localhost -p 2222
|
||||
onyx-cli serve --host 0.0.0.0 --port 2222
|
||||
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h`,
|
||||
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h
|
||||
onyx-cli serve --server-url https://my-onyx.example.com
|
||||
onyx-cli serve --api-server-url http://api_server:8080 # bypass nginx
|
||||
onyx-cli serve --config-file /etc/onyx-cli/config.json # global flag`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
serverCfg := config.Load()
|
||||
serverCfg := loadConfig(cmd)
|
||||
if cmd.Flags().Changed("server-url") {
|
||||
serverCfg.ServerURL = serverURL
|
||||
}
|
||||
if cmd.Flags().Changed("api-server-url") {
|
||||
serverCfg.InternalURL = apiServerURL
|
||||
}
|
||||
if serverCfg.ServerURL == "" {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "server URL is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
@@ -333,7 +340,7 @@ environment variable (the --host-key flag takes precedence).`,
|
||||
var envErr string
|
||||
|
||||
if apiKey != "" {
|
||||
if err := validateAPIKey(serverCfg.ServerURL, apiKey); err != nil {
|
||||
if err := validateAPIKey(serverCfg, apiKey); err != nil {
|
||||
envErr = fmt.Sprintf("ONYX_API_KEY from SSH environment is invalid: %s", err.Error())
|
||||
apiKey = ""
|
||||
}
|
||||
@@ -341,11 +348,8 @@ environment variable (the --host-key flag takes precedence).`,
|
||||
|
||||
if apiKey != "" {
|
||||
// Env key is valid — go straight to the TUI.
|
||||
cfg := config.OnyxCliConfig{
|
||||
ServerURL: serverCfg.ServerURL,
|
||||
APIKey: apiKey,
|
||||
DefaultAgentID: serverCfg.DefaultAgentID,
|
||||
}
|
||||
cfg := serverCfg
|
||||
cfg.APIKey = apiKey
|
||||
return tui.NewModel(cfg), []tea.ProgramOption{
|
||||
tea.WithAltScreen(),
|
||||
tea.WithMouseCellMotion(),
|
||||
@@ -446,6 +450,10 @@ environment variable (the --host-key flag takes precedence).`,
|
||||
defaultServeRateLimitCacheSize,
|
||||
"Maximum number of IP limiter entries tracked in memory",
|
||||
)
|
||||
cmd.Flags().StringVar(&serverURL, "server-url", "",
|
||||
"Onyx server URL (overrides config file and $"+config.EnvServerURL+")")
|
||||
cmd.Flags().StringVar(&apiServerURL, "api-server-url", "",
|
||||
"API server URL for direct access, bypassing nginx (overrides $"+config.EnvAPIServerURL+")")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -23,19 +23,21 @@ is valid. Also reports the server version and warns if it is below the
|
||||
minimum required.`,
|
||||
Example: ` onyx-cli validate-config`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfgPath := effectiveConfigPath(cmd)
|
||||
|
||||
// Check config file
|
||||
if !config.ConfigExists() {
|
||||
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", config.ConfigFilePath())
|
||||
if _, err := os.Stat(cfgPath); err != nil {
|
||||
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", cfgPath)
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
cfg := loadConfig(cmd)
|
||||
|
||||
// Check API key
|
||||
if !cfg.IsConfigured() {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "API key is missing\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", cfgPath)
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Server: %s\n", cfg.ServerURL)
|
||||
|
||||
// Test connection
|
||||
|
||||
@@ -16,18 +16,30 @@ import (
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Client is the Onyx API client.
|
||||
type Client struct {
|
||||
baseURL string
|
||||
serverURL string // root server URL (for reachability checks)
|
||||
baseURL string // API base URL (includes /api when going through nginx)
|
||||
apiKey string
|
||||
httpClient *http.Client // default 30s timeout for quick requests
|
||||
longHTTPClient *http.Client // 5min timeout for streaming/uploads
|
||||
}
|
||||
|
||||
// NewClient creates a new API client from config.
|
||||
//
|
||||
// When InternalURL is set, requests go directly to the API server (no /api
|
||||
// prefix needed — mirrors INTERNAL_URL in the web server). Otherwise,
|
||||
// requests go through the nginx proxy at ServerURL which strips /api.
|
||||
func NewClient(cfg config.OnyxCliConfig) *Client {
|
||||
baseURL := apiBaseURL(cfg)
|
||||
log.WithFields(log.Fields{
|
||||
"server_url": cfg.ServerURL,
|
||||
"internal_url": cfg.InternalURL,
|
||||
"base_url": baseURL,
|
||||
}).Debug("creating API client")
|
||||
var transport *http.Transport
|
||||
if t, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
transport = t.Clone()
|
||||
@@ -35,8 +47,9 @@ func NewClient(cfg config.OnyxCliConfig) *Client {
|
||||
transport = &http.Transport{}
|
||||
}
|
||||
return &Client{
|
||||
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
|
||||
apiKey: cfg.APIKey,
|
||||
serverURL: strings.TrimRight(cfg.ServerURL, "/"),
|
||||
baseURL: baseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: transport,
|
||||
@@ -48,14 +61,27 @@ func NewClient(cfg config.OnyxCliConfig) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// apiBaseURL returns the base URL for API requests. When InternalURL is set,
|
||||
// it points directly at the API server. Otherwise it goes through the nginx
|
||||
// proxy at ServerURL/api.
|
||||
func apiBaseURL(cfg config.OnyxCliConfig) string {
|
||||
if cfg.InternalURL != "" {
|
||||
return strings.TrimRight(cfg.InternalURL, "/")
|
||||
}
|
||||
return strings.TrimRight(cfg.ServerURL, "/") + "/api"
|
||||
}
|
||||
|
||||
// UpdateConfig replaces the client's config.
|
||||
func (c *Client) UpdateConfig(cfg config.OnyxCliConfig) {
|
||||
c.baseURL = strings.TrimRight(cfg.ServerURL, "/")
|
||||
c.serverURL = strings.TrimRight(cfg.ServerURL, "/")
|
||||
c.baseURL = apiBaseURL(cfg)
|
||||
c.apiKey = cfg.APIKey
|
||||
}
|
||||
|
||||
func (c *Client) newRequest(ctx context.Context, method, path string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body)
|
||||
url := c.baseURL + path
|
||||
log.WithFields(log.Fields{"method": method, "url": url}).Debug("API request")
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -87,12 +113,16 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody any, r
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("url", req.URL.String()).Debug("API request failed")
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
log.WithFields(log.Fields{"url": req.URL.String(), "status": resp.StatusCode}).Debug("API response")
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
log.WithFields(log.Fields{"status": resp.StatusCode, "body": string(respBody)}).Debug("API error response")
|
||||
return &OnyxAPIError{StatusCode: resp.StatusCode, Detail: string(respBody)}
|
||||
}
|
||||
|
||||
@@ -105,16 +135,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody any, r
|
||||
// TestConnection checks if the server is reachable and credentials are valid.
|
||||
// Returns nil on success, or an error with a descriptive message on failure.
|
||||
func (c *Client) TestConnection(ctx context.Context) error {
|
||||
// Step 1: Basic reachability
|
||||
req, err := c.newRequest(ctx, "GET", "/", nil)
|
||||
// Step 1: Basic reachability (hit the server root, not the API prefix)
|
||||
reachURL := c.serverURL
|
||||
if reachURL == "" {
|
||||
reachURL = c.baseURL
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"reach_url": reachURL,
|
||||
"base_url": c.baseURL,
|
||||
}).Debug("testing connection — step 1: reachability")
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reachURL+"/", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to %s: %w", c.baseURL, err)
|
||||
return fmt.Errorf("cannot connect to %s: %w", reachURL, err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot connect to %s — is the server running?", c.baseURL)
|
||||
log.WithError(err).Debug("reachability check failed")
|
||||
return fmt.Errorf("cannot connect to %s — is the server running?", reachURL)
|
||||
}
|
||||
log.WithField("status", resp.StatusCode).Debug("reachability check response")
|
||||
_ = resp.Body.Close()
|
||||
|
||||
serverHeader := strings.ToLower(resp.Header.Get("Server"))
|
||||
@@ -127,7 +167,8 @@ func (c *Client) TestConnection(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Step 2: Authenticated check
|
||||
req2, err := c.newRequest(ctx, "GET", "/api/me", nil)
|
||||
log.WithField("url", c.baseURL+"/me").Debug("testing connection — step 2: auth check")
|
||||
req2, err := c.newRequest(ctx, "GET", "/me", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("server reachable but API error: %w", err)
|
||||
}
|
||||
@@ -167,7 +208,7 @@ func (c *Client) TestConnection(ctx context.Context) error {
|
||||
// ListAgents returns visible agents.
|
||||
func (c *Client) ListAgents(ctx context.Context) ([]models.AgentSummary, error) {
|
||||
var raw []models.AgentSummary
|
||||
if err := c.doJSON(ctx, "GET", "/api/persona", nil, &raw); err != nil {
|
||||
if err := c.doJSON(ctx, "GET", "/persona", nil, &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result []models.AgentSummary
|
||||
@@ -184,7 +225,7 @@ func (c *Client) ListChatSessions(ctx context.Context) ([]models.ChatSessionDeta
|
||||
var resp struct {
|
||||
Sessions []models.ChatSessionDetails `json:"sessions"`
|
||||
}
|
||||
if err := c.doJSON(ctx, "GET", "/api/chat/get-user-chat-sessions", nil, &resp); err != nil {
|
||||
if err := c.doJSON(ctx, "GET", "/chat/get-user-chat-sessions", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp.Sessions, nil
|
||||
@@ -193,7 +234,7 @@ func (c *Client) ListChatSessions(ctx context.Context) ([]models.ChatSessionDeta
|
||||
// GetChatSession returns full details for a session.
|
||||
func (c *Client) GetChatSession(ctx context.Context, sessionID string) (*models.ChatSessionDetailResponse, error) {
|
||||
var resp models.ChatSessionDetailResponse
|
||||
if err := c.doJSON(ctx, "GET", "/api/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
|
||||
if err := c.doJSON(ctx, "GET", "/chat/get-chat-session/"+sessionID, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
@@ -210,7 +251,7 @@ func (c *Client) RenameChatSession(ctx context.Context, sessionID string, name *
|
||||
var resp struct {
|
||||
NewName string `json:"new_name"`
|
||||
}
|
||||
if err := c.doJSON(ctx, "PUT", "/api/chat/rename-chat-session", payload, &resp); err != nil {
|
||||
if err := c.doJSON(ctx, "PUT", "/chat/rename-chat-session", payload, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.NewName, nil
|
||||
@@ -236,7 +277,7 @@ func (c *Client) UploadFile(ctx context.Context, filePath string) (*models.FileD
|
||||
}
|
||||
_ = writer.Close()
|
||||
|
||||
req, err := c.newRequest(ctx, "POST", "/api/user/projects/file/upload", &buf)
|
||||
req, err := c.newRequest(ctx, "POST", "/user/projects/file/upload", &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -275,7 +316,7 @@ func (c *Client) GetBackendVersion(ctx context.Context) (string, error) {
|
||||
var resp struct {
|
||||
BackendVersion string `json:"backend_version"`
|
||||
}
|
||||
if err := c.doJSON(ctx, "GET", "/api/version", nil, &resp); err != nil {
|
||||
if err := c.doJSON(ctx, "GET", "/version", nil, &resp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.BackendVersion, nil
|
||||
@@ -283,7 +324,7 @@ func (c *Client) GetBackendVersion(ctx context.Context) (string, error) {
|
||||
|
||||
// StopChatSession sends a stop signal for a streaming session (best-effort).
|
||||
func (c *Client) StopChatSession(ctx context.Context, sessionID string) {
|
||||
req, err := c.newRequest(ctx, "POST", "/api/chat/stop-chat-session/"+sessionID, nil)
|
||||
req, err := c.newRequest(ctx, "POST", "/chat/stop-chat-session/"+sessionID, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (c *Client) SendMessageStream(
|
||||
return
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/api/chat/send-chat-message", nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+"/chat/send-chat-message", nil)
|
||||
if err != nil {
|
||||
ch <- models.ErrorEvent{Error: fmt.Sprintf("request error: %v", err), IsRetryable: false}
|
||||
return
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
const (
|
||||
EnvServerURL = "ONYX_SERVER_URL"
|
||||
EnvAPIServerURL = "ONYX_API_SERVER_URL"
|
||||
EnvAPIKey = "ONYX_API_KEY"
|
||||
EnvAgentID = "ONYX_PERSONA_ID"
|
||||
EnvSSHHostKey = "ONYX_SSH_HOST_KEY"
|
||||
@@ -27,6 +28,7 @@ type Features struct {
|
||||
// OnyxCliConfig holds the CLI configuration.
|
||||
type OnyxCliConfig struct {
|
||||
ServerURL string `json:"server_url"`
|
||||
InternalURL string `json:"internal_url,omitempty"`
|
||||
APIKey string `json:"api_key"`
|
||||
DefaultAgentID int `json:"default_persona_id"`
|
||||
Features Features `json:"features,omitempty"`
|
||||
@@ -78,30 +80,47 @@ func ConfigExists() bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LoadFromDisk reads config from the file only, without applying environment
|
||||
// variable overrides. Use this when you need the persisted config values
|
||||
// (e.g., to preserve them during a save operation).
|
||||
func LoadFromDisk() OnyxCliConfig {
|
||||
// LoadFromDisk reads config from the given file path without applying
|
||||
// environment variable overrides. Use this when you need the persisted
|
||||
// config values (e.g., to preserve them during a save operation).
|
||||
// If no path is provided, the default config file path is used.
|
||||
func LoadFromDisk(path ...string) OnyxCliConfig {
|
||||
p := ConfigFilePath()
|
||||
if len(path) > 0 && path[0] != "" {
|
||||
p = path[0]
|
||||
}
|
||||
|
||||
cfg := DefaultConfig()
|
||||
|
||||
data, err := os.ReadFile(ConfigFilePath())
|
||||
data, err := os.ReadFile(p)
|
||||
if err == nil {
|
||||
if jsonErr := json.Unmarshal(data, &cfg); jsonErr != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", ConfigFilePath(), jsonErr)
|
||||
fmt.Fprintf(os.Stderr, "warning: config file %s is malformed: %v (using defaults)\n", p, jsonErr)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Load reads config from file and applies environment variable overrides.
|
||||
func Load() OnyxCliConfig {
|
||||
cfg := LoadFromDisk()
|
||||
// Load reads config from the given file path and applies environment variable
|
||||
// overrides. If no path is provided, the default config file path is used.
|
||||
func Load(path ...string) OnyxCliConfig {
|
||||
cfg := LoadFromDisk(path...)
|
||||
applyEnvOverrides(&cfg)
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Environment overrides
|
||||
func applyEnvOverrides(cfg *OnyxCliConfig) {
|
||||
if v := os.Getenv(EnvServerURL); v != "" {
|
||||
cfg.ServerURL = v
|
||||
}
|
||||
// ONYX_API_SERVER_URL takes precedence; fall back to INTERNAL_URL
|
||||
// (the env var used by the web server) for compatibility.
|
||||
if v := os.Getenv(EnvAPIServerURL); v != "" {
|
||||
cfg.InternalURL = v
|
||||
} else if v := os.Getenv("INTERNAL_URL"); v != "" {
|
||||
cfg.InternalURL = v
|
||||
}
|
||||
if v := os.Getenv(EnvAPIKey); v != "" {
|
||||
cfg.APIKey = v
|
||||
}
|
||||
@@ -117,8 +136,6 @@ func Load() OnyxCliConfig {
|
||||
fmt.Fprintf(os.Stderr, "warning: invalid value %q for %s (expected true/false), ignoring\n", v, EnvStreamMarkdown)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Save writes the config to disk, creating parent directories if needed.
|
||||
|
||||
@@ -19,6 +19,6 @@ dependencies:
|
||||
version: 5.4.0
|
||||
- name: code-interpreter
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
version: 0.3.2
|
||||
digest: sha256:74908ea45ace2b4be913ff762772e6d87e40bab64e92c6662aa51730eaeb9d87
|
||||
generated: "2026-04-06T15:34:02.597166-07:00"
|
||||
version: 0.3.1
|
||||
digest: sha256:4965b6ea3674c37163832a2192cd3bc8004f2228729fca170af0b9f457e8f987
|
||||
generated: "2026-03-02T15:29:39.632344-08:00"
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.40
|
||||
version: 0.4.39
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
@@ -45,6 +45,6 @@ dependencies:
|
||||
repository: https://charts.min.io/
|
||||
condition: minio.enabled
|
||||
- name: code-interpreter
|
||||
version: 0.3.2
|
||||
version: 0.3.1
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
condition: codeInterpreter.enabled
|
||||
|
||||
@@ -67,9 +67,6 @@ spec:
|
||||
- "/bin/sh"
|
||||
- "-c"
|
||||
- |
|
||||
{{- if .Values.api.runUpdateCaCertificates }}
|
||||
update-ca-certificates &&
|
||||
{{- end }}
|
||||
alembic upgrade head &&
|
||||
echo "Starting Onyx Api Server" &&
|
||||
uvicorn onyx.main:app --host {{ .Values.global.host }} --port {{ .Values.api.containerPorts.server }}
|
||||
|
||||
@@ -504,18 +504,6 @@ api:
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
# Run update-ca-certificates before starting the server.
|
||||
# Useful when mounting custom CA certificates via volumes/volumeMounts.
|
||||
# NOTE: Requires the container to run as root (runAsUser: 0).
|
||||
# CA certificate files must be mounted under /usr/local/share/ca-certificates/
|
||||
# with a .crt extension (e.g. /usr/local/share/ca-certificates/my-ca.crt).
|
||||
# NOTE: Python HTTP clients (requests, httpx) use certifi's bundle by default
|
||||
# and will not pick up the system CA store automatically. Set the following
|
||||
# environment variables via configMap values (loaded through envFrom) to make them use the updated system bundle:
|
||||
# REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
# SSL_CERT_FILE: /etc/ssl/certs/ca-certificates.crt
|
||||
runUpdateCaCertificates: false
|
||||
|
||||
|
||||
######################################################################
|
||||
#
|
||||
|
||||
@@ -30,10 +30,7 @@ target "backend" {
|
||||
context = "backend"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = [
|
||||
"type=registry,ref=${BACKEND_REPOSITORY}:latest",
|
||||
"type=registry,ref=${BACKEND_REPOSITORY}:edge",
|
||||
]
|
||||
cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${BACKEND_REPOSITORY}:${TAG}"]
|
||||
@@ -43,10 +40,7 @@ target "web" {
|
||||
context = "web"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = [
|
||||
"type=registry,ref=${WEB_SERVER_REPOSITORY}:latest",
|
||||
"type=registry,ref=${WEB_SERVER_REPOSITORY}:edge",
|
||||
]
|
||||
cache-from = ["type=registry,ref=${WEB_SERVER_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${WEB_SERVER_REPOSITORY}:${TAG}"]
|
||||
@@ -57,10 +51,7 @@ target "model-server" {
|
||||
|
||||
dockerfile = "Dockerfile.model_server"
|
||||
|
||||
cache-from = [
|
||||
"type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest",
|
||||
"type=registry,ref=${MODEL_SERVER_REPOSITORY}:edge",
|
||||
]
|
||||
cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"]
|
||||
@@ -82,10 +73,7 @@ target "cli" {
|
||||
context = "cli"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = [
|
||||
"type=registry,ref=${CLI_REPOSITORY}:latest",
|
||||
"type=registry,ref=${CLI_REPOSITORY}:edge",
|
||||
]
|
||||
cache-from = ["type=registry,ref=${CLI_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${CLI_REPOSITORY}:${TAG}"]
|
||||
|
||||
225
docs/METRICS.md
225
docs/METRICS.md
@@ -6,11 +6,11 @@ All Prometheus metrics live in the `backend/onyx/server/metrics/` package. Follo
|
||||
|
||||
### 1. Choose the right file (or create a new one)
|
||||
|
||||
| File | Purpose |
|
||||
| ------------------------------------- | -------------------------------------------- |
|
||||
| `metrics/slow_requests.py` | Slow request counter + callback |
|
||||
| `metrics/postgres_connection_pool.py` | SQLAlchemy connection pool metrics |
|
||||
| `metrics/prometheus_setup.py` | FastAPI instrumentator config (orchestrator) |
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `metrics/slow_requests.py` | Slow request counter + callback |
|
||||
| `metrics/postgres_connection_pool.py` | SQLAlchemy connection pool metrics |
|
||||
| `metrics/prometheus_setup.py` | FastAPI instrumentator config (orchestrator) |
|
||||
|
||||
If your metric is a standalone concern (e.g. cache hit rates, queue depths), create a new file under `metrics/` and keep one metric concept per file.
|
||||
|
||||
@@ -30,7 +30,6 @@ _my_counter = Counter(
|
||||
```
|
||||
|
||||
**Naming conventions:**
|
||||
|
||||
- Prefix all metric names with `onyx_`
|
||||
- Counters: `_total` suffix (e.g. `onyx_api_slow_requests_total`)
|
||||
- Histograms: `_seconds` or `_bytes` suffix for durations/sizes
|
||||
@@ -108,26 +107,26 @@ These metrics are exposed at `GET /metrics` on the API server.
|
||||
|
||||
### Built-in (via `prometheus-fastapi-instrumentator`)
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ------------------------------------- | --------- | ----------------------------- | ------------------------------------------------- |
|
||||
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
|
||||
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
|
||||
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (custom buckets for P95/P99) |
|
||||
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
|
||||
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
|
||||
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
|
||||
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
|
||||
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (custom buckets for P95/P99) |
|
||||
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
|
||||
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
|
||||
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |
|
||||
|
||||
### Custom (via `onyx.server.metrics`)
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ------------------------------ | ------- | ----------------------------- | ---------------------------------------------------------------- |
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `onyx_api_slow_requests_total` | Counter | `method`, `handler`, `status` | Requests exceeding `SLOW_REQUEST_THRESHOLD_SECONDS` (default 1s) |
|
||||
|
||||
### Configuration
|
||||
|
||||
| Env Var | Default | Description |
|
||||
| -------------------------------- | ------- | -------------------------------------------- |
|
||||
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |
|
||||
| Env Var | Default | Description |
|
||||
|---------|---------|-------------|
|
||||
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |
|
||||
|
||||
### Instrumentator Settings
|
||||
|
||||
@@ -142,188 +141,44 @@ These metrics provide visibility into SQLAlchemy connection pool state across al
|
||||
|
||||
### Pool State (via custom Prometheus collector — snapshot on each scrape)
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| -------------------------- | ----- | -------- | ----------------------------------------------- |
|
||||
| `onyx_db_pool_checked_out` | Gauge | `engine` | Currently checked-out connections |
|
||||
| `onyx_db_pool_checked_in` | Gauge | `engine` | Idle connections available in the pool |
|
||||
| `onyx_db_pool_overflow` | Gauge | `engine` | Current overflow connections beyond `pool_size` |
|
||||
| `onyx_db_pool_size` | Gauge | `engine` | Configured pool size (constant) |
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `onyx_db_pool_checked_out` | Gauge | `engine` | Currently checked-out connections |
|
||||
| `onyx_db_pool_checked_in` | Gauge | `engine` | Idle connections available in the pool |
|
||||
| `onyx_db_pool_overflow` | Gauge | `engine` | Current overflow connections beyond `pool_size` |
|
||||
| `onyx_db_pool_size` | Gauge | `engine` | Configured pool size (constant) |
|
||||
|
||||
### Pool Lifecycle (via SQLAlchemy pool event listeners)
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ---------------------------------------- | ------- | -------- | ---------------------------------------- |
|
||||
| `onyx_db_pool_checkout_total` | Counter | `engine` | Total connection checkouts from the pool |
|
||||
| `onyx_db_pool_checkin_total` | Counter | `engine` | Total connection checkins to the pool |
|
||||
| `onyx_db_pool_connections_created_total` | Counter | `engine` | Total new database connections created |
|
||||
| `onyx_db_pool_invalidations_total` | Counter | `engine` | Total connection invalidations |
|
||||
| `onyx_db_pool_checkout_timeout_total` | Counter | `engine` | Total connection checkout timeouts |
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `onyx_db_pool_checkout_total` | Counter | `engine` | Total connection checkouts from the pool |
|
||||
| `onyx_db_pool_checkin_total` | Counter | `engine` | Total connection checkins to the pool |
|
||||
| `onyx_db_pool_connections_created_total` | Counter | `engine` | Total new database connections created |
|
||||
| `onyx_db_pool_invalidations_total` | Counter | `engine` | Total connection invalidations |
|
||||
| `onyx_db_pool_checkout_timeout_total` | Counter | `engine` | Total connection checkout timeouts |
|
||||
|
||||
### Per-Endpoint Attribution (via pool events + endpoint context middleware)
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| -------------------------------------- | --------- | ------------------- | ----------------------------------------------- |
|
||||
| `onyx_db_connections_held_by_endpoint` | Gauge | `handler`, `engine` | DB connections currently held, by endpoint |
|
||||
| `onyx_db_connection_hold_seconds` | Histogram | `handler`, `engine` | Duration a DB connection is held by an endpoint |
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `onyx_db_connections_held_by_endpoint` | Gauge | `handler`, `engine` | DB connections currently held, by endpoint |
|
||||
| `onyx_db_connection_hold_seconds` | Histogram | `handler`, `engine` | Duration a DB connection is held by an endpoint |
|
||||
|
||||
Engine label values: `sync` (main read-write), `async` (async sessions), `readonly` (read-only user).
|
||||
|
||||
Connections from background tasks (Celery) or boot-time warmup appear as `handler="unknown"`.
|
||||
|
||||
## Celery Worker Metrics
|
||||
|
||||
Celery workers expose metrics via a standalone Prometheus HTTP server (separate from the API server's `/metrics` endpoint). Each worker type runs its own server on a dedicated port.
|
||||
|
||||
### Metrics Server (`onyx.server.metrics.metrics_server`)
|
||||
|
||||
| Env Var | Default | Description |
|
||||
| ---------------------------- | ------------------- | ----------------------------------------------------- |
|
||||
| `PROMETHEUS_METRICS_PORT` | _(per worker type)_ | Override the default port for this worker |
|
||||
| `PROMETHEUS_METRICS_ENABLED` | `true` | Set to `false` to disable the metrics server entirely |
|
||||
|
||||
Default ports:
|
||||
|
||||
| Worker | Port |
|
||||
| --------------- | ---- |
|
||||
| `docfetching` | 9092 |
|
||||
| `docprocessing` | 9093 |
|
||||
| `monitoring` | 9096 |
|
||||
|
||||
Workers without a default port and no `PROMETHEUS_METRICS_PORT` env var will skip starting the server.
|
||||
|
||||
### Generic Task Lifecycle Metrics (`onyx.server.metrics.celery_task_metrics`)
|
||||
|
||||
Push-based metrics that fire on Celery signals for all tasks on the worker.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ----------------------------------- | --------- | ------------------------------- | ----------------------------------------------------------------------------- |
|
||||
| `onyx_celery_task_started_total` | Counter | `task_name`, `queue` | Total tasks started |
|
||||
| `onyx_celery_task_completed_total` | Counter | `task_name`, `queue`, `outcome` | Total tasks completed (`outcome`: `success` or `failure`) |
|
||||
| `onyx_celery_task_duration_seconds` | Histogram | `task_name`, `queue` | Task execution duration. Buckets: 1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600 |
|
||||
| `onyx_celery_tasks_active` | Gauge | `task_name`, `queue` | Currently executing tasks |
|
||||
| `onyx_celery_task_retried_total` | Counter | `task_name`, `queue` | Total task retries |
|
||||
| `onyx_celery_task_revoked_total` | Counter | `task_name` | Total tasks revoked (cancelled) |
|
||||
| `onyx_celery_task_rejected_total` | Counter | `task_name` | Total tasks rejected by worker |
|
||||
|
||||
Stale start-time entries (tasks killed via SIGTERM/OOM where `task_postrun` never fires) are evicted after 1 hour.
|
||||
|
||||
### Per-Connector Indexing Metrics (`onyx.server.metrics.indexing_task_metrics`)
|
||||
|
||||
Enriches docfetching and docprocessing tasks with connector-level labels. Silently no-ops for all other tasks.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ------------------------------------- | --------- | ----------------------------------------------------------- | ---------------------------------------- |
|
||||
| `onyx_indexing_task_started_total` | Counter | `task_name`, `source`, `tenant_id`, `cc_pair_id` | Indexing tasks started per connector |
|
||||
| `onyx_indexing_task_completed_total` | Counter | `task_name`, `source`, `tenant_id`, `cc_pair_id`, `outcome` | Indexing tasks completed per connector |
|
||||
| `onyx_indexing_task_duration_seconds` | Histogram | `task_name`, `source`, `tenant_id` | Indexing task duration by connector type |
|
||||
|
||||
`connector_name` is intentionally excluded from these push-based counters to avoid unbounded cardinality (it's a free-form user string). The pull-based collectors on the monitoring worker include it since they have bounded cardinality (one series per connector).
|
||||
|
||||
### Pull-Based Collectors (`onyx.server.metrics.indexing_pipeline`)
|
||||
|
||||
Registered only in the **Monitoring** worker. Collectors query Redis/Postgres at scrape time with a 30-second TTL cache.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ------------------------------------ | ----- | ------- | ----------------------------------- |
|
||||
| `onyx_queue_depth` | Gauge | `queue` | Celery queue length |
|
||||
| `onyx_queue_unacked` | Gauge | `queue` | Unacknowledged messages per queue |
|
||||
| `onyx_queue_oldest_task_age_seconds` | Gauge | `queue` | Age of the oldest task in the queue |
|
||||
|
||||
Plus additional connector health, index attempt, and worker heartbeat metrics — see `indexing_pipeline.py` for the full list.
|
||||
|
||||
### Adding Metrics to a Worker
|
||||
|
||||
Currently only the docfetching and docprocessing workers have push-based task metrics wired up. To add metrics to another worker (e.g. heavy, light, primary):
|
||||
|
||||
**1. Import and call the generic handlers from the worker's signal handlers:**
|
||||
|
||||
```python
|
||||
from onyx.server.metrics.celery_task_metrics import (
|
||||
on_celery_task_prerun,
|
||||
on_celery_task_postrun,
|
||||
on_celery_task_retry,
|
||||
on_celery_task_revoked,
|
||||
on_celery_task_rejected,
|
||||
)
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(sender, task_id, task, args, kwargs, **kwds):
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
```
|
||||
|
||||
Do the same for `task_postrun`, `task_retry`, `task_revoked`, and `task_rejected` — see `apps/docfetching.py` for the complete example.
|
||||
|
||||
**2. Start the metrics server on `worker_ready`:**
|
||||
|
||||
```python
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender, **kwargs):
|
||||
start_metrics_server("your_worker_type")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
```
|
||||
|
||||
Add a default port for your worker type in `metrics_server.py`'s `_DEFAULT_PORTS` dict, or set `PROMETHEUS_METRICS_PORT` in the environment.
|
||||
|
||||
**3. (Optional) Add domain-specific enrichment:**
|
||||
|
||||
If your tasks need richer labels beyond `task_name`/`queue`, create a new module in `server/metrics/` following `indexing_task_metrics.py`:
|
||||
|
||||
- Define Counters/Histograms with your domain labels
|
||||
- Write `on_<domain>_task_prerun` / `on_<domain>_task_postrun` handlers that filter by task name and no-op for others
|
||||
- Call them from the worker's signal handlers alongside the generic ones
|
||||
|
||||
**Cardinality warning:** Never use user-defined free-form strings as metric labels — they create unbounded cardinality. Use IDs or enum values. If you need free-form labels, use pull-based collectors (monitoring worker) where cardinality is naturally bounded.
|
||||
|
||||
### Current Worker Integration Status
|
||||
|
||||
| Worker | Generic Task Metrics | Domain Metrics | Metrics Server |
|
||||
| -------------------- | -------------------- | -------------- | ------------------------------------ |
|
||||
| Docfetching | ✓ | ✓ (indexing) | ✓ (port 9092) |
|
||||
| Docprocessing | ✓ | ✓ (indexing) | ✓ (port 9093) |
|
||||
| Monitoring | — | — | ✓ (port 9096, pull-based collectors) |
|
||||
| Primary | — | — | — |
|
||||
| Light | — | — | — |
|
||||
| Heavy | — | — | — |
|
||||
| User File Processing | — | — | — |
|
||||
| KG Processing | — | — | — |
|
||||
|
||||
### Example PromQL Queries (Celery)
|
||||
|
||||
```promql
|
||||
# Task completion rate by worker queue
|
||||
sum by (queue) (rate(onyx_celery_task_completed_total[5m]))
|
||||
|
||||
# P95 task duration for pruning tasks
|
||||
histogram_quantile(0.95,
|
||||
sum by (le) (rate(onyx_celery_task_duration_seconds_bucket{task_name=~".*pruning.*"}[5m])))
|
||||
|
||||
# Task failure rate
|
||||
sum by (task_name) (rate(onyx_celery_task_completed_total{outcome="failure"}[5m]))
|
||||
/ sum by (task_name) (rate(onyx_celery_task_completed_total[5m]))
|
||||
|
||||
# Active tasks per queue
|
||||
sum by (queue) (onyx_celery_tasks_active)
|
||||
|
||||
# Indexing throughput by source type
|
||||
sum by (source) (rate(onyx_indexing_task_completed_total{outcome="success"}[5m]))
|
||||
|
||||
# Queue depth — are tasks backing up?
|
||||
onyx_queue_depth > 100
|
||||
```
|
||||
|
||||
## OpenSearch Search Metrics
|
||||
|
||||
These metrics track OpenSearch search latency and throughput. Collected via `onyx.server.metrics.opensearch_search`.
|
||||
|
||||
| Metric | Type | Labels | Description |
|
||||
| ------------------------------------------------ | --------- | ------------- | --------------------------------------------------------------------------- |
|
||||
| Metric | Type | Labels | Description |
|
||||
|--------|------|--------|-------------|
|
||||
| `onyx_opensearch_search_client_duration_seconds` | Histogram | `search_type` | Client-side end-to-end latency (network + serialization + server execution) |
|
||||
| `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field |
|
||||
| `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch |
|
||||
| `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches |
|
||||
| `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field |
|
||||
| `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch |
|
||||
| `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches |
|
||||
|
||||
Search type label values: See `OpenSearchSearchType`.
|
||||
|
||||
|
||||
@@ -70,10 +70,6 @@ backend = [
|
||||
"lazy_imports==1.0.1",
|
||||
"lxml==5.3.0",
|
||||
"Mako==1.2.4",
|
||||
# NOTE: Do not update without understanding the patching behavior in
|
||||
# get_markitdown_converter in
|
||||
# backend/onyx/file_processing/extract_file_text.py and what impacts
|
||||
# updating might have on this behavior.
|
||||
"markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2",
|
||||
"mcp[cli]==1.26.0",
|
||||
"msal==1.34.0",
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
import { Card } from "@opal/components/cards/card/components";
|
||||
import { Content, SizePreset } from "@opal/layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgEmpty } from "@opal/icons";
|
||||
import type {
|
||||
IconFunctionComponent,
|
||||
PaddingVariants,
|
||||
RichStr,
|
||||
} from "@opal/types";
|
||||
import type { IconFunctionComponent, PaddingVariants } from "@opal/types";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type EmptyMessageCardBaseProps = {
|
||||
type EmptyMessageCardProps = {
|
||||
/** Icon displayed alongside the title. */
|
||||
icon?: IconFunctionComponent;
|
||||
|
||||
/** Primary message text. */
|
||||
title: string | RichStr;
|
||||
title: string;
|
||||
|
||||
/** Padding preset for the card. @default "md" */
|
||||
padding?: PaddingVariants;
|
||||
@@ -25,30 +21,16 @@ type EmptyMessageCardBaseProps = {
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
};
|
||||
|
||||
type EmptyMessageCardProps =
|
||||
| (EmptyMessageCardBaseProps & {
|
||||
/** @default "secondary" */
|
||||
sizePreset?: "secondary";
|
||||
})
|
||||
| (EmptyMessageCardBaseProps & {
|
||||
sizePreset: "main-ui";
|
||||
/** Description text. Only supported when `sizePreset` is `"main-ui"`. */
|
||||
description?: string | RichStr;
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// EmptyMessageCard
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function EmptyMessageCard(props: EmptyMessageCardProps) {
|
||||
const {
|
||||
sizePreset = "secondary",
|
||||
icon = SvgEmpty,
|
||||
title,
|
||||
padding = "md",
|
||||
ref,
|
||||
} = props;
|
||||
|
||||
function EmptyMessageCard({
|
||||
icon = SvgEmpty,
|
||||
title,
|
||||
padding = "md",
|
||||
ref,
|
||||
}: EmptyMessageCardProps) {
|
||||
return (
|
||||
<Card
|
||||
ref={ref}
|
||||
@@ -57,23 +39,13 @@ function EmptyMessageCard(props: EmptyMessageCardProps) {
|
||||
padding={padding}
|
||||
rounding="md"
|
||||
>
|
||||
{sizePreset === "secondary" ? (
|
||||
<Content
|
||||
icon={icon}
|
||||
title={title}
|
||||
sizePreset="secondary"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
/>
|
||||
) : (
|
||||
<Content
|
||||
icon={icon}
|
||||
title={title}
|
||||
description={"description" in props ? props.description : undefined}
|
||||
sizePreset={sizePreset}
|
||||
variant="section"
|
||||
/>
|
||||
)}
|
||||
<Content
|
||||
icon={icon}
|
||||
title={title}
|
||||
sizePreset="secondary"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,16 +1,41 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/animations/styles.css";
|
||||
import React from "react";
|
||||
import React, { createContext, useContext, useState, useCallback } from "react";
|
||||
import { cn } from "@opal/utils";
|
||||
import type { WithoutStyles, ExtremaSizeVariants } from "@opal/types";
|
||||
import { widthVariants } from "@opal/shared";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// Context-per-group registry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type HoverableInteraction = "rest" | "hover";
|
||||
/**
|
||||
* Lazily-created map of group names to React contexts.
|
||||
*
|
||||
* Each group gets its own `React.Context<boolean | null>` so that a
|
||||
* `Hoverable.Item` only re-renders when its *own* group's hover state
|
||||
* changes — not when any unrelated group changes.
|
||||
*
|
||||
* The default value is `null` (no provider found), which lets
|
||||
* `Hoverable.Item` distinguish "no Root ancestor" from "Root says
|
||||
* not hovered" and throw when `group` was explicitly specified.
|
||||
*/
|
||||
const contextMap = new Map<string, React.Context<boolean | null>>();
|
||||
|
||||
function getOrCreateContext(group: string): React.Context<boolean | null> {
|
||||
let ctx = contextMap.get(group);
|
||||
if (!ctx) {
|
||||
ctx = createContext<boolean | null>(null);
|
||||
ctx.displayName = `HoverableContext(${group})`;
|
||||
contextMap.set(group, ctx);
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HoverableRootProps
|
||||
extends WithoutStyles<React.HTMLAttributes<HTMLDivElement>> {
|
||||
@@ -18,17 +43,6 @@ interface HoverableRootProps
|
||||
group: string;
|
||||
/** Width preset. @default "auto" */
|
||||
widthVariant?: ExtremaSizeVariants;
|
||||
/**
|
||||
* JS-controllable interaction state override.
|
||||
*
|
||||
* - `"rest"` (default): items are shown/hidden by CSS `:hover`.
|
||||
* - `"hover"`: forces items visible regardless of hover state. Useful when
|
||||
* a hoverable action opens a modal — set `interaction="hover"` while the
|
||||
* modal is open so the user can see which element they're interacting with.
|
||||
*
|
||||
* @default "rest"
|
||||
*/
|
||||
interaction?: HoverableInteraction;
|
||||
/** Ref forwarded to the root `<div>`. */
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
}
|
||||
@@ -51,10 +65,12 @@ interface HoverableItemProps
|
||||
/**
|
||||
* Hover-tracking container for a named group.
|
||||
*
|
||||
* Uses a `data-hover-group` attribute and CSS `:hover` to control
|
||||
* descendant `Hoverable.Item` visibility. No React state or context —
|
||||
* the browser natively removes `:hover` when modals/portals steal
|
||||
* pointer events, preventing stale hover state.
|
||||
* Wraps children in a `<div>` that tracks mouse-enter / mouse-leave and
|
||||
* provides the hover state via a per-group React context.
|
||||
*
|
||||
* Nesting works because each `Hoverable.Root` creates a **new** context
|
||||
* provider that shadows the parent — so an inner `Hoverable.Item group="b"`
|
||||
* reads from the inner provider, not the outer `group="a"` provider.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
@@ -71,20 +87,70 @@ function HoverableRoot({
|
||||
group,
|
||||
children,
|
||||
widthVariant = "full",
|
||||
interaction = "rest",
|
||||
ref,
|
||||
onMouseEnter: consumerMouseEnter,
|
||||
onMouseLeave: consumerMouseLeave,
|
||||
onFocusCapture: consumerFocusCapture,
|
||||
onBlurCapture: consumerBlurCapture,
|
||||
...props
|
||||
}: HoverableRootProps) {
|
||||
const [hovered, setHovered] = useState(false);
|
||||
const [focused, setFocused] = useState(false);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(true);
|
||||
consumerMouseEnter?.(e);
|
||||
},
|
||||
[consumerMouseEnter]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
setHovered(false);
|
||||
consumerMouseLeave?.(e);
|
||||
},
|
||||
[consumerMouseLeave]
|
||||
);
|
||||
|
||||
const onFocusCapture = useCallback(
|
||||
(e: React.FocusEvent<HTMLDivElement>) => {
|
||||
setFocused(true);
|
||||
consumerFocusCapture?.(e);
|
||||
},
|
||||
[consumerFocusCapture]
|
||||
);
|
||||
|
||||
const onBlurCapture = useCallback(
|
||||
(e: React.FocusEvent<HTMLDivElement>) => {
|
||||
if (
|
||||
!(e.relatedTarget instanceof Node) ||
|
||||
!e.currentTarget.contains(e.relatedTarget)
|
||||
) {
|
||||
setFocused(false);
|
||||
}
|
||||
consumerBlurCapture?.(e);
|
||||
},
|
||||
[consumerBlurCapture]
|
||||
);
|
||||
|
||||
const active = hovered || focused;
|
||||
const GroupContext = getOrCreateContext(group);
|
||||
|
||||
return (
|
||||
<div
|
||||
{...props}
|
||||
ref={ref}
|
||||
className={cn(widthVariants[widthVariant])}
|
||||
data-hover-group={group}
|
||||
data-interaction={interaction !== "rest" ? interaction : undefined}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
<GroupContext.Provider value={active}>
|
||||
<div
|
||||
{...props}
|
||||
ref={ref}
|
||||
className={cn(widthVariants[widthVariant])}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseLeave={onMouseLeave}
|
||||
onFocusCapture={onFocusCapture}
|
||||
onBlurCapture={onBlurCapture}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</GroupContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -96,10 +162,13 @@ function HoverableRoot({
|
||||
* An element whose visibility is controlled by hover state.
|
||||
*
|
||||
* **Local mode** (`group` omitted): the item handles hover on its own
|
||||
* element via CSS `:hover`.
|
||||
* element via CSS `:hover`. This is the core abstraction.
|
||||
*
|
||||
* **Group mode** (`group` provided): visibility is driven by CSS `:hover`
|
||||
* on the nearest `Hoverable.Root` ancestor via `[data-hover-group]:hover`.
|
||||
* **Group mode** (`group` provided): visibility is driven by a matching
|
||||
* `Hoverable.Root` ancestor's hover state via React context. If no
|
||||
* matching Root is found, an error is thrown.
|
||||
*
|
||||
* Uses data-attributes for variant styling (see `styles.css`).
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
@@ -115,6 +184,8 @@ function HoverableRoot({
|
||||
* </Hoverable.Item>
|
||||
* </Hoverable.Root>
|
||||
* ```
|
||||
*
|
||||
* @throws If `group` is specified but no matching `Hoverable.Root` ancestor exists.
|
||||
*/
|
||||
function HoverableItem({
|
||||
group,
|
||||
@@ -123,6 +194,17 @@ function HoverableItem({
|
||||
ref,
|
||||
...props
|
||||
}: HoverableItemProps) {
|
||||
const contextValue = useContext(
|
||||
group ? getOrCreateContext(group) : NOOP_CONTEXT
|
||||
);
|
||||
|
||||
if (group && contextValue === null) {
|
||||
throw new Error(
|
||||
`Hoverable.Item group="${group}" has no matching Hoverable.Root ancestor. ` +
|
||||
`Either wrap it in <Hoverable.Root group="${group}"> or remove the group prop for local hover.`
|
||||
);
|
||||
}
|
||||
|
||||
const isLocal = group === undefined;
|
||||
|
||||
return (
|
||||
@@ -131,6 +213,9 @@ function HoverableItem({
|
||||
ref={ref}
|
||||
className={cn("hoverable-item")}
|
||||
data-hoverable-variant={variant}
|
||||
data-hoverable-active={
|
||||
isLocal ? undefined : contextValue ? "true" : undefined
|
||||
}
|
||||
data-hoverable-local={isLocal ? "true" : undefined}
|
||||
>
|
||||
{children}
|
||||
@@ -138,6 +223,9 @@ function HoverableItem({
|
||||
);
|
||||
}
|
||||
|
||||
/** Stable context used when no group is specified (local mode). */
|
||||
const NOOP_CONTEXT = createContext<boolean | null>(null);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compound export
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -145,16 +233,18 @@ function HoverableItem({
|
||||
/**
|
||||
* Hoverable compound component for hover-to-reveal patterns.
|
||||
*
|
||||
* Entirely CSS-driven — no React state or context. The browser's native
|
||||
* `:hover` pseudo-class handles all state, which means hover is
|
||||
* automatically cleared when modals/portals steal pointer events.
|
||||
* Provides two sub-components:
|
||||
*
|
||||
* - `Hoverable.Root` — Container with `data-hover-group`. CSS `:hover`
|
||||
* on this element reveals descendant `Hoverable.Item` elements.
|
||||
* - `Hoverable.Root` — A container that tracks hover state for a named group
|
||||
* and provides it via React context.
|
||||
*
|
||||
* - `Hoverable.Item` — Hidden by default. In group mode, revealed when
|
||||
* the ancestor Root is hovered. In local mode (no `group`), revealed
|
||||
* when the item itself is hovered.
|
||||
* - `Hoverable.Item` — The core abstraction. On its own (no `group`), it
|
||||
* applies local CSS `:hover` for the variant effect. When `group` is
|
||||
* specified, it reads hover state from the nearest matching
|
||||
* `Hoverable.Root` — and throws if no matching Root is found.
|
||||
*
|
||||
* Supports nesting: a child `Hoverable.Root` shadows the parent's context,
|
||||
* so each group's items only respond to their own root's hover.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
@@ -186,5 +276,4 @@ export {
|
||||
type HoverableRootProps,
|
||||
type HoverableItemProps,
|
||||
type HoverableItemVariant,
|
||||
type HoverableInteraction,
|
||||
};
|
||||
|
||||
@@ -7,20 +7,8 @@
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Group mode — Root :hover controls descendant item visibility via CSS.
|
||||
Exclude local-mode items so they aren't revealed by an ancestor root. */
|
||||
[data-hover-group]:hover
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:not(
|
||||
[data-hoverable-local]
|
||||
) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Interaction override — force items visible via JS */
|
||||
[data-hover-group][data-interaction="hover"]
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:not(
|
||||
[data-hoverable-local]
|
||||
) {
|
||||
/* Group mode — Root controls visibility via React context */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"][data-hoverable-active="true"] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
@@ -29,16 +17,7 @@
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Group focus — any focusable descendant of the Root receives keyboard focus,
|
||||
revealing all group items (same behavior as hover). */
|
||||
[data-hover-group]:focus-within
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:not(
|
||||
[data-hoverable-local]
|
||||
) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Local focus — item (or a focusable descendant) receives keyboard focus */
|
||||
/* Focus — item (or a focusable descendant) receives keyboard focus */
|
||||
.hoverable-item[data-hoverable-variant="opacity-on-hover"]:has(:focus-visible) {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
6
web/package-lock.json
generated
6
web/package-lock.json
generated
@@ -18122,9 +18122,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/vite": {
|
||||
"version": "6.4.2",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.2.tgz",
|
||||
"integrity": "sha512-2N/55r4JDJ4gdrCvGgINMy+HH3iRpNIz8K6SFwVsA+JbQScLiC+clmAxBgwiSPgcG9U15QmvqCGWzMbqda5zGQ==",
|
||||
"version": "6.4.1",
|
||||
"resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz",
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
|
||||
@@ -1 +1 @@
|
||||
export { default } from "@/refresh-pages/admin/LLMProviderConfigurationPage";
|
||||
export { default } from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
@@ -32,10 +32,8 @@ import {
|
||||
OpenRouterFetchParams,
|
||||
LiteLLMProxyFetchParams,
|
||||
BifrostFetchParams,
|
||||
OpenAICompatibleFetchParams,
|
||||
OpenAICompatibleModelResponse,
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter, SvgPlug } from "@opal/icons";
|
||||
import { SvgAws, SvgBifrost, SvgOpenrouter } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
export const AGGREGATOR_PROVIDERS = new Set([
|
||||
@@ -46,7 +44,6 @@ export const AGGREGATOR_PROVIDERS = new Set([
|
||||
"lm_studio",
|
||||
"litellm_proxy",
|
||||
"bifrost",
|
||||
"openai_compatible",
|
||||
"vertex_ai",
|
||||
]);
|
||||
|
||||
@@ -85,7 +82,6 @@ export const getProviderIcon = (
|
||||
openrouter: SvgOpenrouter,
|
||||
litellm_proxy: LiteLLMIcon,
|
||||
bifrost: SvgBifrost,
|
||||
openai_compatible: SvgPlug,
|
||||
vertex_ai: GeminiIcon,
|
||||
};
|
||||
|
||||
@@ -415,64 +411,6 @@ export const fetchBifrostModels = async (
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches models from a generic OpenAI-compatible server.
|
||||
* Uses snake_case params to match API structure.
|
||||
*/
|
||||
export const fetchOpenAICompatibleModels = async (
|
||||
params: OpenAICompatibleFetchParams
|
||||
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
|
||||
const apiBase = params.api_base;
|
||||
if (!apiBase) {
|
||||
return { models: [], error: "API Base is required" };
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
"/api/admin/llm/openai-compatible/available-models",
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
api_base: apiBase,
|
||||
api_key: params.api_key,
|
||||
provider_name: params.provider_name,
|
||||
}),
|
||||
signal: params.signal,
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = "Failed to fetch models";
|
||||
try {
|
||||
const errorData = await response.json();
|
||||
errorMessage = errorData.detail || errorData.message || errorMessage;
|
||||
} catch {
|
||||
// ignore JSON parsing errors
|
||||
}
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
|
||||
const data: OpenAICompatibleModelResponse[] = await response.json();
|
||||
const models: ModelConfiguration[] = data.map((modelData) => ({
|
||||
name: modelData.name,
|
||||
display_name: modelData.display_name,
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: modelData.supports_reasoning,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
} catch (error) {
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error";
|
||||
return { models: [], error: errorMessage };
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches LiteLLM Proxy models directly without any form state dependencies.
|
||||
* Uses snake_case params to match API structure.
|
||||
@@ -593,13 +531,6 @@ export const fetchModels = async (
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return fetchOpenAICompatibleModels({
|
||||
api_base: formValues.api_base,
|
||||
api_key: formValues.api_key,
|
||||
provider_name: formValues.name,
|
||||
signal,
|
||||
});
|
||||
default:
|
||||
return { models: [], error: `Unknown provider: ${providerName}` };
|
||||
}
|
||||
@@ -614,7 +545,6 @@ export function canProviderFetchModels(providerName?: string) {
|
||||
case LLMProviderName.OPENROUTER:
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
case LLMProviderName.BIFROST:
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -64,50 +64,50 @@ export default function CreateRateLimitModal({
|
||||
title="Create a Token Rate Limit"
|
||||
onClose={() => setIsOpen(false)}
|
||||
/>
|
||||
<Formik
|
||||
initialValues={{
|
||||
enabled: true,
|
||||
period_hours: "",
|
||||
token_budget: "",
|
||||
target_scope: forSpecificScope || Scope.GLOBAL,
|
||||
user_group_id: forSpecificUserGroup,
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
period_hours: Yup.number()
|
||||
.required("Time Window is a required field")
|
||||
.min(1, "Time Window must be at least 1 hour"),
|
||||
token_budget: Yup.number()
|
||||
.required("Token Budget is a required field")
|
||||
.min(1, "Token Budget must be at least 1"),
|
||||
target_scope: Yup.string().required(
|
||||
"Target Scope is a required field"
|
||||
),
|
||||
user_group_id: Yup.string().test(
|
||||
"user_group_id",
|
||||
"User Group is a required field",
|
||||
(value, context) => {
|
||||
return (
|
||||
context.parent.target_scope !== "user_group" ||
|
||||
(context.parent.target_scope === "user_group" &&
|
||||
value !== undefined)
|
||||
);
|
||||
}
|
||||
),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
onSubmit(
|
||||
values.target_scope,
|
||||
Number(values.period_hours),
|
||||
Number(values.token_budget),
|
||||
Number(values.user_group_id)
|
||||
);
|
||||
return formikHelpers.setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => (
|
||||
<Form className="flex flex-col h-full min-h-0 overflow-visible">
|
||||
<Modal.Body>
|
||||
<Modal.Body>
|
||||
<Formik
|
||||
initialValues={{
|
||||
enabled: true,
|
||||
period_hours: "",
|
||||
token_budget: "",
|
||||
target_scope: forSpecificScope || Scope.GLOBAL,
|
||||
user_group_id: forSpecificUserGroup,
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
period_hours: Yup.number()
|
||||
.required("Time Window is a required field")
|
||||
.min(1, "Time Window must be at least 1 hour"),
|
||||
token_budget: Yup.number()
|
||||
.required("Token Budget is a required field")
|
||||
.min(1, "Token Budget must be at least 1"),
|
||||
target_scope: Yup.string().required(
|
||||
"Target Scope is a required field"
|
||||
),
|
||||
user_group_id: Yup.string().test(
|
||||
"user_group_id",
|
||||
"User Group is a required field",
|
||||
(value, context) => {
|
||||
return (
|
||||
context.parent.target_scope !== "user_group" ||
|
||||
(context.parent.target_scope === "user_group" &&
|
||||
value !== undefined)
|
||||
);
|
||||
}
|
||||
),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
onSubmit(
|
||||
values.target_scope,
|
||||
Number(values.period_hours),
|
||||
Number(values.token_budget),
|
||||
Number(values.user_group_id)
|
||||
);
|
||||
return formikHelpers.setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => (
|
||||
<Form className="overflow-visible px-2">
|
||||
{!forSpecificScope && (
|
||||
<SelectorFormField
|
||||
name="target_scope"
|
||||
@@ -147,15 +147,13 @@ export default function CreateRateLimitModal({
|
||||
type="number"
|
||||
placeholder=""
|
||||
/>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<Button disabled={isSubmitting} type="submit">
|
||||
Create
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { Text } from "@opal/components";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { SvgEyeOff, SvgX } from "@opal/icons";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import AgentMessage, {
|
||||
AgentMessageProps,
|
||||
} from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { markdown } from "@opal/utils";
|
||||
|
||||
export interface MultiModelPanelProps {
|
||||
/** Provider name for icon lookup */
|
||||
provider: string;
|
||||
/** Model name for icon lookup and display */
|
||||
modelName: string;
|
||||
/** Display-friendly model name */
|
||||
displayName: string;
|
||||
/** Whether this panel is the preferred/selected response */
|
||||
isPreferred: boolean;
|
||||
/** Whether this panel is currently hidden */
|
||||
isHidden: boolean;
|
||||
/** Whether this is a non-preferred panel in selection mode (pushed off-screen) */
|
||||
isNonPreferredInSelection: boolean;
|
||||
/** Callback when user clicks this panel to select as preferred */
|
||||
onSelect: () => void;
|
||||
/** Callback to hide/show this panel */
|
||||
onToggleVisibility: () => void;
|
||||
/** Props to pass through to AgentMessage */
|
||||
agentMessageProps: AgentMessageProps;
|
||||
}
|
||||
|
||||
/**
|
||||
* A single model's response panel within the multi-model view.
|
||||
*
|
||||
* Renders in two states:
|
||||
* - **Hidden** — compact header strip only (provider icon + strikethrough name + show button).
|
||||
* - **Visible** — full header plus `AgentMessage` body. Clicking anywhere on a
|
||||
* visible non-preferred panel marks it as preferred.
|
||||
*
|
||||
* The `isNonPreferredInSelection` flag disables pointer events on the body and
|
||||
* hides the footer so the panel acts as a passive comparison surface.
|
||||
*/
|
||||
export default function MultiModelPanel({
|
||||
provider,
|
||||
modelName,
|
||||
displayName,
|
||||
isPreferred,
|
||||
isHidden,
|
||||
isNonPreferredInSelection,
|
||||
onSelect,
|
||||
onToggleVisibility,
|
||||
agentMessageProps,
|
||||
}: MultiModelPanelProps) {
|
||||
const ProviderIcon = getProviderIcon(provider, modelName);
|
||||
|
||||
const handlePanelClick = useCallback(() => {
|
||||
if (!isHidden && !isPreferred) onSelect();
|
||||
}, [isHidden, isPreferred, onSelect]);
|
||||
|
||||
const header = (
|
||||
<div
|
||||
className={cn(
|
||||
"rounded-12",
|
||||
isPreferred ? "bg-background-tint-02" : "bg-background-tint-00"
|
||||
)}
|
||||
>
|
||||
<ContentAction
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
paddingVariant="lg"
|
||||
icon={ProviderIcon}
|
||||
title={isHidden ? markdown(`~~${displayName}~~`) : displayName}
|
||||
rightChildren={
|
||||
<div className="flex items-center gap-1 px-2">
|
||||
{isPreferred && (
|
||||
<span className="text-action-link-05 shrink-0">
|
||||
<Text font="secondary-body" color="inherit" nowrap>
|
||||
Preferred Response
|
||||
</Text>
|
||||
</span>
|
||||
)}
|
||||
{!isPreferred && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={isHidden ? SvgEyeOff : SvgX}
|
||||
size="md"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onToggleVisibility();
|
||||
}}
|
||||
tooltip={isHidden ? "Show response" : "Hide response"}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
// Hidden/collapsed panel — just the header row
|
||||
if (isHidden) {
|
||||
return header;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col gap-3 min-w-0 rounded-16 transition-colors",
|
||||
!isPreferred && "cursor-pointer hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={handlePanelClick}
|
||||
>
|
||||
{header}
|
||||
<div className={cn(isNonPreferredInSelection && "pointer-events-none")}>
|
||||
<AgentMessage
|
||||
{...agentMessageProps}
|
||||
hideFooter={isNonPreferredInSelection}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,372 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useMemo, useEffect, useRef } from "react";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { RegenerationFactory } from "@/app/app/message/messageComponents/AgentMessage";
|
||||
import MultiModelPanel from "@/app/app/message/MultiModelPanel";
|
||||
import { MultiModelResponse } from "@/app/app/message/interfaces";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface MultiModelResponseViewProps {
|
||||
responses: MultiModelResponse[];
|
||||
chatState: FullChatState;
|
||||
llmManager: LlmManager | null;
|
||||
onRegenerate?: RegenerationFactory;
|
||||
parentMessage?: Message | null;
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
onMessageSelection?: (nodeId: number) => void;
|
||||
/** Called whenever the set of hidden panel indices changes */
|
||||
onHiddenPanelsChange?: (hidden: Set<number>) => void;
|
||||
}
|
||||
|
||||
// How many pixels of a non-preferred panel are visible at the viewport edge
|
||||
const PEEK_W = 64;
|
||||
// Uniform panel width used in the selection-mode carousel
|
||||
const SELECTION_PANEL_W = 400;
|
||||
// Compact width for hidden panels in the carousel track
|
||||
const HIDDEN_PANEL_W = 220;
|
||||
// Generation-mode panel widths (from Figma)
|
||||
const GEN_PANEL_W_2 = 640; // 2 panels side-by-side
|
||||
const GEN_PANEL_W_3 = 436; // 3 panels side-by-side
|
||||
// Gap between panels — matches CSS gap-6 (24px)
|
||||
const PANEL_GAP = 24;
|
||||
// Minimum panel width before horizontal scroll kicks in
|
||||
const MIN_PANEL_W = 300;
|
||||
|
||||
/**
|
||||
* Renders N model responses side-by-side with two layout modes:
|
||||
*
|
||||
* **Generation mode** — equal-width panels in a horizontally-scrollable row.
|
||||
* Panel width is determined by the number of visible (non-hidden) panels.
|
||||
*
|
||||
* **Selection mode** — activated when the user clicks a panel to mark it as
|
||||
* preferred. All panels (including hidden ones) sit in a fixed-width carousel
|
||||
* track. A CSS `translateX` transform slides the track so the preferred panel
|
||||
* is centered in the viewport; the other panels peek in from the edges through
|
||||
* a mask gradient. Non-preferred visible panels are height-capped to the
|
||||
* preferred panel's measured height, dimmed at 50% opacity, and receive a
|
||||
* bottom fade-out overlay.
|
||||
*
|
||||
* Hidden panels render as a compact header-only strip at `HIDDEN_PANEL_W` in
|
||||
* both modes and are excluded from layout width calculations.
|
||||
*/
|
||||
export default function MultiModelResponseView({
|
||||
responses,
|
||||
chatState,
|
||||
llmManager,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onHiddenPanelsChange,
|
||||
}: MultiModelResponseViewProps) {
|
||||
const [preferredIndex, setPreferredIndex] = useState<number | null>(null);
|
||||
const [hiddenPanels, setHiddenPanels] = useState<Set<number>>(new Set());
|
||||
// Controls animation: false = panels at start position, true = panels at peek position
|
||||
const [selectionEntered, setSelectionEntered] = useState(false);
|
||||
// Measures the overflow-hidden carousel container for responsive preferred-panel sizing.
|
||||
const [trackContainerW, setTrackContainerW] = useState(0);
|
||||
const roRef = useRef<ResizeObserver | null>(null);
|
||||
const trackContainerRef = useCallback((el: HTMLDivElement | null) => {
|
||||
if (roRef.current) {
|
||||
roRef.current.disconnect();
|
||||
roRef.current = null;
|
||||
}
|
||||
if (!el) return;
|
||||
const ro = new ResizeObserver(([entry]) => {
|
||||
setTrackContainerW(entry?.contentRect.width ?? 0);
|
||||
});
|
||||
ro.observe(el);
|
||||
setTrackContainerW(el.offsetWidth);
|
||||
roRef.current = ro;
|
||||
}, []);
|
||||
|
||||
// Measures the preferred panel's height to cap non-preferred panels in selection mode.
|
||||
const [preferredPanelHeight, setPreferredPanelHeight] = useState<
|
||||
number | null
|
||||
>(null);
|
||||
const preferredRoRef = useRef<ResizeObserver | null>(null);
|
||||
// Tracks which non-preferred panels overflow the preferred height cap
|
||||
const [overflowingPanels, setOverflowingPanels] = useState<Set<number>>(
|
||||
new Set()
|
||||
);
|
||||
|
||||
const preferredPanelRef = useCallback((el: HTMLDivElement | null) => {
|
||||
if (preferredRoRef.current) {
|
||||
preferredRoRef.current.disconnect();
|
||||
preferredRoRef.current = null;
|
||||
}
|
||||
if (!el) {
|
||||
setPreferredPanelHeight(null);
|
||||
return;
|
||||
}
|
||||
const ro = new ResizeObserver(([entry]) => {
|
||||
setPreferredPanelHeight(entry?.contentRect.height ?? 0);
|
||||
});
|
||||
ro.observe(el);
|
||||
setPreferredPanelHeight(el.offsetHeight);
|
||||
preferredRoRef.current = ro;
|
||||
}, []);
|
||||
|
||||
const isGenerating = useMemo(
|
||||
() => responses.some((r) => r.isGenerating),
|
||||
[responses]
|
||||
);
|
||||
|
||||
// Non-hidden responses — used for layout width decisions and selection-mode gating
|
||||
const visibleResponses = useMemo(
|
||||
() => responses.filter((r) => !hiddenPanels.has(r.modelIndex)),
|
||||
[responses, hiddenPanels]
|
||||
);
|
||||
|
||||
const toggleVisibility = useCallback(
|
||||
(modelIndex: number) => {
|
||||
setHiddenPanels((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(modelIndex)) {
|
||||
next.delete(modelIndex);
|
||||
} else {
|
||||
// Don't hide the last visible panel
|
||||
const visibleCount = responses.length - next.size;
|
||||
if (visibleCount <= 1) return prev;
|
||||
next.add(modelIndex);
|
||||
}
|
||||
onHiddenPanelsChange?.(next);
|
||||
return next;
|
||||
});
|
||||
},
|
||||
[responses.length, onHiddenPanelsChange]
|
||||
);
|
||||
|
||||
const handleSelectPreferred = useCallback(
|
||||
(modelIndex: number) => {
|
||||
if (isGenerating) return;
|
||||
setPreferredIndex(modelIndex);
|
||||
const response = responses.find((r) => r.modelIndex === modelIndex);
|
||||
if (!response) return;
|
||||
if (onMessageSelection) {
|
||||
onMessageSelection(response.nodeId);
|
||||
}
|
||||
},
|
||||
[isGenerating, responses, onMessageSelection]
|
||||
);
|
||||
|
||||
// Clear preferred selection when generation starts
|
||||
useEffect(() => {
|
||||
if (isGenerating) {
|
||||
setPreferredIndex(null);
|
||||
}
|
||||
}, [isGenerating]);
|
||||
|
||||
// Find preferred panel position — used for both the selection guard and carousel layout
|
||||
const preferredIdx = responses.findIndex(
|
||||
(r) => r.modelIndex === preferredIndex
|
||||
);
|
||||
|
||||
// Selection mode when preferred is set, found in responses, not generating, and at least 2 visible panels
|
||||
const showSelectionMode =
|
||||
preferredIndex !== null &&
|
||||
preferredIdx !== -1 &&
|
||||
!isGenerating &&
|
||||
visibleResponses.length > 1;
|
||||
|
||||
// Trigger the slide-out animation one frame after entering selection mode
|
||||
useEffect(() => {
|
||||
if (!showSelectionMode) {
|
||||
setSelectionEntered(false);
|
||||
return;
|
||||
}
|
||||
const raf = requestAnimationFrame(() => setSelectionEntered(true));
|
||||
return () => cancelAnimationFrame(raf);
|
||||
}, [showSelectionMode]);
|
||||
|
||||
// Build panel props — isHidden reflects actual hidden state
|
||||
const buildPanelProps = useCallback(
|
||||
(response: MultiModelResponse, isNonPreferred: boolean) => ({
|
||||
provider: response.provider,
|
||||
modelName: response.modelName,
|
||||
displayName: response.displayName,
|
||||
isPreferred: preferredIndex === response.modelIndex,
|
||||
isHidden: hiddenPanels.has(response.modelIndex),
|
||||
isNonPreferredInSelection: isNonPreferred,
|
||||
onSelect: () => handleSelectPreferred(response.modelIndex),
|
||||
onToggleVisibility: () => toggleVisibility(response.modelIndex),
|
||||
agentMessageProps: {
|
||||
rawPackets: response.packets,
|
||||
packetCount: response.packetCount,
|
||||
chatState,
|
||||
nodeId: response.nodeId,
|
||||
messageId: response.messageId,
|
||||
currentFeedback: response.currentFeedback,
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
},
|
||||
}),
|
||||
[
|
||||
preferredIndex,
|
||||
hiddenPanels,
|
||||
handleSelectPreferred,
|
||||
toggleVisibility,
|
||||
chatState,
|
||||
llmManager,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
]
|
||||
);
|
||||
|
||||
if (showSelectionMode) {
|
||||
// ── Selection Layout (transform-based carousel) ──
|
||||
//
|
||||
// All panels (including hidden) sit in the track at their original A/B/C positions.
|
||||
// Hidden panels use HIDDEN_PANEL_W; non-preferred use SELECTION_PANEL_W;
|
||||
// preferred uses dynamicPrefW (up to GEN_PANEL_W_2).
|
||||
const n = responses.length;
|
||||
|
||||
const dynamicPrefW =
|
||||
trackContainerW > 0
|
||||
? Math.min(trackContainerW - 2 * (PEEK_W + PANEL_GAP), GEN_PANEL_W_2)
|
||||
: GEN_PANEL_W_2;
|
||||
|
||||
const selectionWidths = responses.map((r, i) => {
|
||||
if (hiddenPanels.has(r.modelIndex)) return HIDDEN_PANEL_W;
|
||||
if (i === preferredIdx) return dynamicPrefW;
|
||||
return SELECTION_PANEL_W;
|
||||
});
|
||||
|
||||
const panelLeftEdges = selectionWidths.reduce<number[]>((acc, w, i) => {
|
||||
acc.push(i === 0 ? 0 : acc[i - 1]! + selectionWidths[i - 1]! + PANEL_GAP);
|
||||
return acc;
|
||||
}, []);
|
||||
|
||||
const preferredCenterInTrack =
|
||||
panelLeftEdges[preferredIdx]! + selectionWidths[preferredIdx]! / 2;
|
||||
|
||||
// Start position: hidden panels at HIDDEN_PANEL_W, visible at SELECTION_PANEL_W
|
||||
const uniformTrackW =
|
||||
responses.reduce(
|
||||
(sum, r) =>
|
||||
sum +
|
||||
(hiddenPanels.has(r.modelIndex) ? HIDDEN_PANEL_W : SELECTION_PANEL_W),
|
||||
0
|
||||
) +
|
||||
(n - 1) * PANEL_GAP;
|
||||
|
||||
const trackTransform = selectionEntered
|
||||
? `translateX(${trackContainerW / 2 - preferredCenterInTrack}px)`
|
||||
: `translateX(${(trackContainerW - uniformTrackW) / 2}px)`;
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={trackContainerRef}
|
||||
className="w-full overflow-hidden"
|
||||
style={{
|
||||
maskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
WebkitMaskImage: `linear-gradient(to right, transparent 0px, black ${PEEK_W}px, black calc(100% - ${PEEK_W}px), transparent 100%)`,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="flex items-start"
|
||||
style={{
|
||||
gap: `${PANEL_GAP}px`,
|
||||
transition: selectionEntered
|
||||
? "transform 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
transform: trackTransform,
|
||||
}}
|
||||
>
|
||||
{responses.map((r, i) => {
|
||||
const isHidden = hiddenPanels.has(r.modelIndex);
|
||||
const isPref = r.modelIndex === preferredIndex;
|
||||
const isNonPref = !isHidden && !isPref;
|
||||
const finalW = selectionWidths[i]!;
|
||||
const startW = isHidden ? HIDDEN_PANEL_W : SELECTION_PANEL_W;
|
||||
const capped = isNonPref && preferredPanelHeight != null;
|
||||
const overflows = capped && overflowingPanels.has(r.modelIndex);
|
||||
return (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
ref={(el) => {
|
||||
if (isPref) preferredPanelRef(el);
|
||||
if (capped && el) {
|
||||
const doesOverflow = el.scrollHeight > el.clientHeight;
|
||||
setOverflowingPanels((prev) => {
|
||||
const had = prev.has(r.modelIndex);
|
||||
if (doesOverflow === had) return prev;
|
||||
const next = new Set(prev);
|
||||
if (doesOverflow) next.add(r.modelIndex);
|
||||
else next.delete(r.modelIndex);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
width: `${selectionEntered ? finalW : startW}px`,
|
||||
flexShrink: 0,
|
||||
transition: selectionEntered
|
||||
? "width 0.45s cubic-bezier(0.2, 0, 0, 1)"
|
||||
: "none",
|
||||
maxHeight: capped ? preferredPanelHeight : undefined,
|
||||
overflow: capped ? "hidden" : undefined,
|
||||
position: capped ? "relative" : undefined,
|
||||
}}
|
||||
>
|
||||
<div className={cn(isNonPref && "opacity-50")}>
|
||||
<MultiModelPanel {...buildPanelProps(r, isNonPref)} />
|
||||
</div>
|
||||
{overflows && (
|
||||
<div
|
||||
className="absolute inset-x-0 bottom-0 h-24 pointer-events-none"
|
||||
style={{
|
||||
background:
|
||||
"linear-gradient(to top, var(--background-tint-01) 0%, transparent 100%)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ── Generation Layout (equal panels side-by-side) ──
|
||||
// Panel width based on number of visible (non-hidden) panels.
|
||||
const panelWidth =
|
||||
visibleResponses.length <= 2 ? GEN_PANEL_W_2 : GEN_PANEL_W_3;
|
||||
|
||||
return (
|
||||
<div className="overflow-x-auto">
|
||||
<div className="flex gap-6 items-start w-fit mx-auto">
|
||||
{responses.map((r) => {
|
||||
const isHidden = hiddenPanels.has(r.modelIndex);
|
||||
return (
|
||||
<div
|
||||
key={r.modelIndex}
|
||||
style={
|
||||
isHidden
|
||||
? {
|
||||
width: HIDDEN_PANEL_W,
|
||||
minWidth: HIDDEN_PANEL_W,
|
||||
maxWidth: HIDDEN_PANEL_W,
|
||||
flexShrink: 0,
|
||||
overflow: "hidden" as const,
|
||||
}
|
||||
: { width: panelWidth, minWidth: MIN_PANEL_W }
|
||||
}
|
||||
>
|
||||
<MultiModelPanel {...buildPanelProps(r, false)} />
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
import { Packet } from "@/app/app/services/streamingModels";
|
||||
import { FeedbackType } from "@/app/app/interfaces";
|
||||
|
||||
export interface MultiModelResponse {
|
||||
modelIndex: number;
|
||||
provider: string;
|
||||
modelName: string;
|
||||
displayName: string;
|
||||
packets: Packet[];
|
||||
packetCount: number;
|
||||
nodeId: number;
|
||||
messageId?: number;
|
||||
|
||||
currentFeedback?: FeedbackType | null;
|
||||
isGenerating?: boolean;
|
||||
}
|
||||
@@ -49,8 +49,6 @@ export interface AgentMessageProps {
|
||||
parentMessage?: Message | null;
|
||||
// Duration in seconds for processing this message (agent messages only)
|
||||
processingDurationSeconds?: number;
|
||||
/** Hide the feedback/toolbar footer (used in multi-model non-preferred panels) */
|
||||
hideFooter?: boolean;
|
||||
}
|
||||
|
||||
// TODO: Consider more robust comparisons:
|
||||
@@ -78,8 +76,7 @@ function arePropsEqual(
|
||||
prev.parentMessage?.messageId === next.parentMessage?.messageId &&
|
||||
prev.llmManager?.isLoadingProviders ===
|
||||
next.llmManager?.isLoadingProviders &&
|
||||
prev.processingDurationSeconds === next.processingDurationSeconds &&
|
||||
prev.hideFooter === next.hideFooter
|
||||
prev.processingDurationSeconds === next.processingDurationSeconds
|
||||
// Skip: chatState.regenerate, chatState.setPresentingDocument,
|
||||
// most of llmManager, onMessageSelection (function/object props)
|
||||
);
|
||||
@@ -98,7 +95,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onRegenerate,
|
||||
parentMessage,
|
||||
processingDurationSeconds,
|
||||
hideFooter,
|
||||
}: AgentMessageProps) {
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
const finalAnswerRef = useRef<HTMLDivElement>(null);
|
||||
@@ -330,7 +326,7 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
</div>
|
||||
|
||||
{/* Feedback buttons - only show when streaming and rendering complete */}
|
||||
{isComplete && !hideFooter && (
|
||||
{isComplete && (
|
||||
<MessageToolbar
|
||||
nodeId={nodeId}
|
||||
messageId={messageId}
|
||||
|
||||
@@ -24,6 +24,7 @@ import {
|
||||
} from "@/app/craft/onboarding/constants";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { buildOnboardingInitialValues as buildInitialValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { testApiKeyHelper } from "@/sections/modals/llmConfig/svc";
|
||||
import OnboardingInfoPages from "@/app/craft/onboarding/components/OnboardingInfoPages";
|
||||
import OnboardingUserInfo from "@/app/craft/onboarding/components/OnboardingUserInfo";
|
||||
@@ -220,8 +221,10 @@ export default function BuildOnboardingModal({
|
||||
setConnectionStatus("testing");
|
||||
setErrorMessage("");
|
||||
|
||||
const baseValues = buildInitialValues();
|
||||
const providerName = `build-mode-${currentProviderConfig.providerName}`;
|
||||
const payload = {
|
||||
...baseValues,
|
||||
name: providerName,
|
||||
provider: currentProviderConfig.providerName,
|
||||
api_key: apiKey,
|
||||
|
||||
@@ -133,7 +133,7 @@ async function createFederatedConnector(
|
||||
|
||||
async function updateFederatedConnector(
|
||||
id: number,
|
||||
credentials: CredentialForm | null,
|
||||
credentials: CredentialForm,
|
||||
config?: ConfigForm
|
||||
): Promise<{ success: boolean; message: string }> {
|
||||
try {
|
||||
@@ -143,7 +143,7 @@ async function updateFederatedConnector(
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
credentials: credentials ?? undefined,
|
||||
credentials,
|
||||
config: config || {},
|
||||
}),
|
||||
});
|
||||
@@ -201,9 +201,7 @@ export function FederatedConnectorForm({
|
||||
const isEditMode = connectorId !== undefined;
|
||||
|
||||
const [formState, setFormState] = useState<FormState>({
|
||||
// In edit mode, don't populate credentials with masked values from the API.
|
||||
// Masked values (e.g. "••••••••••••") would be saved back and corrupt the real credentials.
|
||||
credentials: isEditMode ? {} : preloadedConnectorData?.credentials || {},
|
||||
credentials: preloadedConnectorData?.credentials || {},
|
||||
config: preloadedConnectorData?.config || {},
|
||||
schema: preloadedCredentialSchema?.credentials || null,
|
||||
configurationSchema: null,
|
||||
@@ -211,7 +209,6 @@ export function FederatedConnectorForm({
|
||||
configurationSchemaError: null,
|
||||
connectorError: null,
|
||||
});
|
||||
const [credentialsModified, setCredentialsModified] = useState(false);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [submitMessage, setSubmitMessage] = useState<string | null>(null);
|
||||
const [submitSuccess, setSubmitSuccess] = useState<boolean | null>(null);
|
||||
@@ -336,7 +333,6 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
|
||||
const handleCredentialChange = (key: string, value: string) => {
|
||||
setCredentialsModified(true);
|
||||
setFormState((prev) => ({
|
||||
...prev,
|
||||
credentials: {
|
||||
@@ -358,11 +354,6 @@ export function FederatedConnectorForm({
|
||||
|
||||
const handleValidateCredentials = async () => {
|
||||
if (!formState.schema) return;
|
||||
if (isEditMode && !credentialsModified) {
|
||||
setSubmitMessage("Enter new credential values before validating.");
|
||||
setSubmitSuccess(false);
|
||||
return;
|
||||
}
|
||||
|
||||
setIsValidating(true);
|
||||
setSubmitMessage(null);
|
||||
@@ -420,10 +411,8 @@ export function FederatedConnectorForm({
|
||||
setSubmitSuccess(null);
|
||||
|
||||
try {
|
||||
const shouldValidateCredentials = !isEditMode || credentialsModified;
|
||||
|
||||
// Validate required fields (skip for credentials in edit mode when unchanged)
|
||||
if (formState.schema && shouldValidateCredentials) {
|
||||
// Validate required fields
|
||||
if (formState.schema) {
|
||||
const missingRequired = Object.entries(formState.schema)
|
||||
.filter(
|
||||
([key, field]) => field.required && !formState.credentials[key]
|
||||
@@ -453,20 +442,16 @@ export function FederatedConnectorForm({
|
||||
}
|
||||
setConfigValidationErrors({});
|
||||
|
||||
// Validate credentials before creating/updating (skip in edit mode when unchanged)
|
||||
if (shouldValidateCredentials) {
|
||||
const validation = await validateCredentials(
|
||||
connector,
|
||||
formState.credentials
|
||||
);
|
||||
if (!validation.success) {
|
||||
setSubmitMessage(
|
||||
`Credential validation failed: ${validation.message}`
|
||||
);
|
||||
setSubmitSuccess(false);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
// Validate credentials before creating/updating
|
||||
const validation = await validateCredentials(
|
||||
connector,
|
||||
formState.credentials
|
||||
);
|
||||
if (!validation.success) {
|
||||
setSubmitMessage(`Credential validation failed: ${validation.message}`);
|
||||
setSubmitSuccess(false);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Create or update the connector
|
||||
@@ -474,7 +459,7 @@ export function FederatedConnectorForm({
|
||||
isEditMode && connectorId
|
||||
? await updateFederatedConnector(
|
||||
connectorId,
|
||||
credentialsModified ? formState.credentials : null,
|
||||
formState.credentials,
|
||||
formState.config
|
||||
)
|
||||
: await createFederatedConnector(
|
||||
@@ -553,16 +538,14 @@ export function FederatedConnectorForm({
|
||||
id={fieldKey}
|
||||
type={fieldSpec.secret ? "password" : "text"}
|
||||
placeholder={
|
||||
isEditMode && !credentialsModified
|
||||
? "•••••••• (leave blank to keep current value)"
|
||||
: fieldSpec.example
|
||||
? String(fieldSpec.example)
|
||||
: fieldSpec.description
|
||||
fieldSpec.example
|
||||
? String(fieldSpec.example)
|
||||
: fieldSpec.description
|
||||
}
|
||||
value={formState.credentials[fieldKey] || ""}
|
||||
onChange={(e) => handleCredentialChange(fieldKey, e.target.value)}
|
||||
className="w-96"
|
||||
required={!isEditMode && fieldSpec.required}
|
||||
required={fieldSpec.required}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -1,10 +1,25 @@
|
||||
"use client";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import React, { createContext, useContext, useCallback } from "react";
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import React, {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useEffect,
|
||||
useCallback,
|
||||
} from "react";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
|
||||
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llmConfig/svc";
|
||||
|
||||
interface ProviderContextType {
|
||||
shouldShowConfigurationNeeded: boolean;
|
||||
providerOptions: WellKnownLLMProviderDescriptor[];
|
||||
refreshProviderInfo: () => Promise<void>;
|
||||
// Expose configured provider instances for components that need it (e.g., onboarding)
|
||||
llmProviders: LLMProviderDescriptor[] | undefined;
|
||||
isLoadingProviders: boolean;
|
||||
hasProviders: boolean;
|
||||
@@ -14,26 +29,79 @@ const ProviderContext = createContext<ProviderContextType | undefined>(
|
||||
undefined
|
||||
);
|
||||
|
||||
const DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY = "defaultLlmProviderTestComplete";
|
||||
|
||||
function checkDefaultLLMProviderTestComplete() {
|
||||
if (typeof window === "undefined") return true;
|
||||
return (
|
||||
localStorage.getItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY) === "true"
|
||||
);
|
||||
}
|
||||
|
||||
function setDefaultLLMProviderTestComplete() {
|
||||
if (typeof window === "undefined") return;
|
||||
localStorage.setItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY, "true");
|
||||
}
|
||||
|
||||
export function ProviderContextProvider({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const { user } = useUser();
|
||||
|
||||
// Use SWR hooks instead of raw fetch
|
||||
const {
|
||||
llmProviders,
|
||||
isLoading: isLoadingProviders,
|
||||
refetch: refetchProviders,
|
||||
} = useLLMProviders();
|
||||
const { llmProviderOptions: providerOptions, refetch: refetchOptions } =
|
||||
useLLMProviderOptions();
|
||||
|
||||
const [defaultCheckSuccessful, setDefaultCheckSuccessful] =
|
||||
useState<boolean>(true);
|
||||
|
||||
// Test the default provider - only runs if test hasn't passed yet
|
||||
const testDefaultProvider = useCallback(async () => {
|
||||
const shouldCheck =
|
||||
!checkDefaultLLMProviderTestComplete() &&
|
||||
(!user || user.role === "admin");
|
||||
|
||||
if (shouldCheck) {
|
||||
const success = await testDefaultProviderSvc();
|
||||
setDefaultCheckSuccessful(success);
|
||||
if (success) {
|
||||
setDefaultLLMProviderTestComplete();
|
||||
}
|
||||
}
|
||||
}, [user]);
|
||||
|
||||
// Test default provider on mount
|
||||
useEffect(() => {
|
||||
testDefaultProvider();
|
||||
}, [testDefaultProvider]);
|
||||
|
||||
const hasProviders = (llmProviders?.length ?? 0) > 0;
|
||||
const validProviderExists = hasProviders && defaultCheckSuccessful;
|
||||
|
||||
const shouldShowConfigurationNeeded =
|
||||
!validProviderExists && (providerOptions?.length ?? 0) > 0;
|
||||
|
||||
const refreshProviderInfo = useCallback(async () => {
|
||||
await refetchProviders();
|
||||
}, [refetchProviders]);
|
||||
// Refetch provider lists and re-test default provider if needed
|
||||
await Promise.all([
|
||||
refetchProviders(),
|
||||
refetchOptions(),
|
||||
testDefaultProvider(),
|
||||
]);
|
||||
}, [refetchProviders, refetchOptions, testDefaultProvider]);
|
||||
|
||||
return (
|
||||
<ProviderContext.Provider
|
||||
value={{
|
||||
shouldShowConfigurationNeeded,
|
||||
providerOptions: providerOptions ?? [],
|
||||
refreshProviderInfo,
|
||||
llmProviders,
|
||||
isLoadingProviders,
|
||||
|
||||
@@ -17,6 +17,7 @@ const mockProviderStatus = {
|
||||
llmProviders: [] as unknown[],
|
||||
isLoadingProviders: false,
|
||||
hasProviders: false,
|
||||
providerOptions: [],
|
||||
refreshProviderInfo: jest.fn(),
|
||||
};
|
||||
|
||||
@@ -70,6 +71,7 @@ describe("useShowOnboarding", () => {
|
||||
mockProviderStatus.llmProviders = [];
|
||||
mockProviderStatus.isLoadingProviders = false;
|
||||
mockProviderStatus.hasProviders = false;
|
||||
mockProviderStatus.providerOptions = [];
|
||||
});
|
||||
|
||||
it("returns showOnboarding=false while providers are loading", () => {
|
||||
@@ -196,6 +198,7 @@ describe("useShowOnboarding", () => {
|
||||
OnboardingStep.Welcome
|
||||
);
|
||||
expect(result.current.onboardingActions).toBeDefined();
|
||||
expect(result.current.llmDescriptors).toEqual([]);
|
||||
});
|
||||
|
||||
describe("localStorage persistence", () => {
|
||||
|
||||
@@ -5,7 +5,6 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import {
|
||||
LLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
LLMProviderResponse,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
@@ -139,12 +138,12 @@ export function useAdminLLMProviders() {
|
||||
* Used inside individual provider modals to pre-populate model lists
|
||||
* before the user has entered credentials.
|
||||
*
|
||||
* @param providerName - The provider's API endpoint name (e.g. "openai", "anthropic").
|
||||
* @param providerEndpoint - The provider's API endpoint name (e.g. "openai", "anthropic").
|
||||
* Pass `null` to suppress the request.
|
||||
*/
|
||||
export function useWellKnownLLMProvider(providerName: LLMProviderName) {
|
||||
export function useWellKnownLLMProvider(providerEndpoint: string | null) {
|
||||
const { data, error, isLoading } = useSWR<WellKnownLLMProviderDescriptor>(
|
||||
providerName ? SWR_KEYS.wellKnownLlmProvider(providerName) : null,
|
||||
providerEndpoint ? SWR_KEYS.wellKnownLlmProvider(providerEndpoint) : null,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useCallback, useEffect, useMemo } from "react";
|
||||
import {
|
||||
MAX_MODELS,
|
||||
SelectedModel,
|
||||
} from "@/refresh-components/popovers/ModelSelector";
|
||||
import { LLMOverride } from "@/app/app/services/lib";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { buildLlmOptions } from "@/refresh-components/popovers/LLMPopover";
|
||||
|
||||
export interface UseMultiModelChatReturn {
|
||||
/** Currently selected models for multi-model comparison. */
|
||||
selectedModels: SelectedModel[];
|
||||
/** Whether multi-model mode is active (>1 model selected). */
|
||||
isMultiModelActive: boolean;
|
||||
/** Add a model to the selection. */
|
||||
addModel: (model: SelectedModel) => void;
|
||||
/** Remove a model by index. */
|
||||
removeModel: (index: number) => void;
|
||||
/** Replace a model at a specific index with a new one. */
|
||||
replaceModel: (index: number, model: SelectedModel) => void;
|
||||
/** Clear all selected models. */
|
||||
clearModels: () => void;
|
||||
/** Build the LLMOverride[] array from selectedModels. */
|
||||
buildLlmOverrides: () => LLMOverride[];
|
||||
/**
|
||||
* Restore multi-model selection from model version strings (e.g. from chat history).
|
||||
* Matches against available llmOptions to reconstruct full SelectedModel objects.
|
||||
*/
|
||||
restoreFromModelNames: (modelNames: string[]) => void;
|
||||
/**
|
||||
* Switch to a single model by name (after user picks a preferred response).
|
||||
* Matches against llmOptions to find the full SelectedModel.
|
||||
*/
|
||||
selectSingleModel: (modelName: string) => void;
|
||||
}
|
||||
|
||||
export default function useMultiModelChat(
|
||||
llmManager: LlmManager
|
||||
): UseMultiModelChatReturn {
|
||||
const [selectedModels, setSelectedModels] = useState<SelectedModel[]>([]);
|
||||
const [defaultInitialized, setDefaultInitialized] = useState(false);
|
||||
|
||||
// Initialize with the default model from llmManager once providers load
|
||||
const llmOptions = useMemo(
|
||||
() =>
|
||||
llmManager.llmProviders ? buildLlmOptions(llmManager.llmProviders) : [],
|
||||
[llmManager.llmProviders]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (defaultInitialized) return;
|
||||
if (llmOptions.length === 0) return;
|
||||
const { currentLlm } = llmManager;
|
||||
// Don't initialize if currentLlm hasn't loaded yet
|
||||
if (!currentLlm.modelName) return;
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.provider === currentLlm.provider &&
|
||||
opt.modelName === currentLlm.modelName
|
||||
);
|
||||
if (match) {
|
||||
setSelectedModels([
|
||||
{
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
},
|
||||
]);
|
||||
setDefaultInitialized(true);
|
||||
}
|
||||
}, [llmOptions, llmManager.currentLlm, defaultInitialized]);
|
||||
|
||||
const isMultiModelActive = selectedModels.length > 1;
|
||||
|
||||
const addModel = useCallback((model: SelectedModel) => {
|
||||
setSelectedModels((prev) => {
|
||||
if (prev.length >= MAX_MODELS) return prev;
|
||||
if (
|
||||
prev.some(
|
||||
(m) =>
|
||||
m.provider === model.provider && m.modelName === model.modelName
|
||||
)
|
||||
) {
|
||||
return prev;
|
||||
}
|
||||
return [...prev, model];
|
||||
});
|
||||
}, []);
|
||||
|
||||
const removeModel = useCallback((index: number) => {
|
||||
setSelectedModels((prev) => prev.filter((_, i) => i !== index));
|
||||
}, []);
|
||||
|
||||
const replaceModel = useCallback((index: number, model: SelectedModel) => {
|
||||
setSelectedModels((prev) => {
|
||||
// Don't replace with a model that's already selected elsewhere
|
||||
if (
|
||||
prev.some(
|
||||
(m, i) =>
|
||||
i !== index &&
|
||||
m.provider === model.provider &&
|
||||
m.modelName === model.modelName
|
||||
)
|
||||
) {
|
||||
return prev;
|
||||
}
|
||||
const next = [...prev];
|
||||
next[index] = model;
|
||||
return next;
|
||||
});
|
||||
}, []);
|
||||
|
||||
const clearModels = useCallback(() => {
|
||||
setSelectedModels([]);
|
||||
}, []);
|
||||
|
||||
const restoreFromModelNames = useCallback(
|
||||
(modelNames: string[]) => {
|
||||
if (modelNames.length < 2 || llmOptions.length === 0) return;
|
||||
const restored: SelectedModel[] = [];
|
||||
for (const name of modelNames) {
|
||||
// Try matching by modelName (raw version string like "claude-opus-4-6")
|
||||
// or by displayName (friendly name like "Claude Opus 4.6")
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.modelName === name ||
|
||||
opt.displayName === name ||
|
||||
opt.name === name
|
||||
);
|
||||
if (match) {
|
||||
restored.push({
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (restored.length >= 2) {
|
||||
setSelectedModels(restored.slice(0, MAX_MODELS));
|
||||
setDefaultInitialized(true);
|
||||
}
|
||||
},
|
||||
[llmOptions]
|
||||
);
|
||||
|
||||
const selectSingleModel = useCallback(
|
||||
(modelName: string) => {
|
||||
if (llmOptions.length === 0) return;
|
||||
const match = llmOptions.find(
|
||||
(opt) =>
|
||||
opt.modelName === modelName ||
|
||||
opt.displayName === modelName ||
|
||||
opt.name === modelName
|
||||
);
|
||||
if (match) {
|
||||
setSelectedModels([
|
||||
{
|
||||
name: match.name,
|
||||
provider: match.provider,
|
||||
modelName: match.modelName,
|
||||
displayName: match.displayName,
|
||||
},
|
||||
]);
|
||||
}
|
||||
},
|
||||
[llmOptions]
|
||||
);
|
||||
|
||||
const buildLlmOverrides = useCallback((): LLMOverride[] => {
|
||||
return selectedModels.map((m) => ({
|
||||
model_provider: m.provider,
|
||||
model_version: m.modelName,
|
||||
display_name: m.displayName,
|
||||
}));
|
||||
}, [selectedModels]);
|
||||
|
||||
return {
|
||||
selectedModels,
|
||||
isMultiModelActive,
|
||||
addModel,
|
||||
removeModel,
|
||||
replaceModel,
|
||||
clearModels,
|
||||
buildLlmOverrides,
|
||||
restoreFromModelNames,
|
||||
selectSingleModel,
|
||||
};
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
OnboardingState,
|
||||
OnboardingStep,
|
||||
} from "@/interfaces/onboarding";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { updateUserPersonalization } from "@/lib/userSettings";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/agents/interfaces";
|
||||
@@ -21,6 +22,7 @@ function getOnboardingCompletedKey(userId: string): string {
|
||||
|
||||
function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
|
||||
state: OnboardingState;
|
||||
llmDescriptors: WellKnownLLMProviderDescriptor[];
|
||||
actions: OnboardingActions;
|
||||
isLoading: boolean;
|
||||
hasProviders: boolean;
|
||||
@@ -33,6 +35,7 @@ function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
|
||||
llmProviders,
|
||||
isLoadingProviders,
|
||||
hasProviders: hasLlmProviders,
|
||||
providerOptions,
|
||||
refreshProviderInfo,
|
||||
} = useProviderStatus();
|
||||
|
||||
@@ -40,6 +43,7 @@ function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
|
||||
const { refetch: refreshPersonaProviders } = useLLMProviders(liveAgent?.id);
|
||||
|
||||
const userName = user?.personalization?.name;
|
||||
const llmDescriptors = providerOptions;
|
||||
|
||||
const nameUpdateTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null
|
||||
@@ -231,6 +235,7 @@ function useOnboardingState(liveAgent?: MinimalPersonaSnapshot): {
|
||||
|
||||
return {
|
||||
state,
|
||||
llmDescriptors,
|
||||
actions: {
|
||||
nextStep,
|
||||
prevStep,
|
||||
@@ -275,6 +280,7 @@ export function useShowOnboarding({
|
||||
const {
|
||||
state: onboardingState,
|
||||
actions: onboardingActions,
|
||||
llmDescriptors,
|
||||
isLoading: isLoadingOnboarding,
|
||||
hasProviders: hasAnyProvider,
|
||||
} = useOnboardingState(liveAgent);
|
||||
@@ -344,6 +350,7 @@ export function useShowOnboarding({
|
||||
onboardingDismissed,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptors,
|
||||
isLoadingOnboarding,
|
||||
hideOnboarding,
|
||||
finishOnboarding,
|
||||
|
||||
@@ -15,7 +15,6 @@ export enum LLMProviderName {
|
||||
LITELLM = "litellm",
|
||||
LITELLM_PROXY = "litellm_proxy",
|
||||
BIFROST = "bifrost",
|
||||
OPENAI_COMPATIBLE = "openai_compatible",
|
||||
CUSTOM = "custom",
|
||||
}
|
||||
|
||||
@@ -123,12 +122,16 @@ export interface LLMProviderFormProps {
|
||||
variant?: LLMModalVariant;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
/** Called after successful provider creation/update. */
|
||||
onSuccess?: () => void | Promise<void>;
|
||||
|
||||
/** The current default model name for this provider (from the global default). */
|
||||
defaultModelName?: string;
|
||||
|
||||
// Onboarding-specific (only when variant === "onboarding")
|
||||
onboardingState?: OnboardingState;
|
||||
onboardingActions?: OnboardingActions;
|
||||
llmDescriptor?: WellKnownLLMProviderDescriptor;
|
||||
}
|
||||
|
||||
// Param types for model fetching functions - use snake_case to match API structure
|
||||
@@ -179,21 +182,6 @@ export interface BifrostModelResponse {
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface OpenAICompatibleFetchParams {
|
||||
api_base?: string;
|
||||
api_key?: string;
|
||||
provider_name?: string;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OpenAICompatibleModelResponse {
|
||||
name: string;
|
||||
display_name: string;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
}
|
||||
|
||||
export interface VertexAIFetchParams {
|
||||
model_configurations?: ModelConfiguration[];
|
||||
}
|
||||
@@ -212,6 +200,5 @@ export type FetchModelsParams =
|
||||
| OpenRouterFetchParams
|
||||
| LiteLLMProxyFetchParams
|
||||
| BifrostFetchParams
|
||||
| OpenAICompatibleFetchParams
|
||||
| VertexAIFetchParams
|
||||
| LMStudioFetchParams;
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import type { RichStr, WithoutStyles } from "@opal/types";
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { resolveStr } from "@opal/components/text/InlineMarkdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { SvgXOctagon, SvgAlertCircle } from "@opal/icons";
|
||||
import { useField, useFormikContext } from "formik";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
@@ -230,27 +229,9 @@ function ErrorTextLayout({ children, type = "error" }: ErrorTextLayoutProps) {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldSeparator - A horizontal rule with inline padding, used to visually separate field groups.
|
||||
*/
|
||||
function FieldSeparator() {
|
||||
return <Separator noPadding className="p-2" />;
|
||||
}
|
||||
|
||||
/**
|
||||
* FieldPadder - Wraps a field in standard horizontal + vertical padding (`p-2 w-full`).
|
||||
*/
|
||||
type FieldPadderProps = WithoutStyles<React.HTMLAttributes<HTMLDivElement>>;
|
||||
function FieldPadder(props: FieldPadderProps) {
|
||||
return <div {...props} className="p-2 w-full" />;
|
||||
}
|
||||
|
||||
export {
|
||||
VerticalInputLayout as Vertical,
|
||||
HorizontalInputLayout as Horizontal,
|
||||
ErrorLayout as Error,
|
||||
ErrorTextLayout,
|
||||
FieldSeparator,
|
||||
FieldPadder,
|
||||
type FieldPadderProps,
|
||||
};
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
SvgOllama,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgPlug,
|
||||
SvgServer,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
@@ -28,7 +27,6 @@ const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
[LLMProviderName.OPENROUTER]: SvgOpenrouter,
|
||||
[LLMProviderName.LM_STUDIO]: SvgLmStudio,
|
||||
[LLMProviderName.BIFROST]: SvgBifrost,
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: SvgPlug,
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: SvgServer,
|
||||
@@ -46,7 +44,6 @@ const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Custom Models",
|
||||
@@ -64,7 +61,6 @@ const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
[LLMProviderName.OPENROUTER]: "OpenRouter",
|
||||
[LLMProviderName.LM_STUDIO]: "LM Studio",
|
||||
[LLMProviderName.BIFROST]: "Bifrost",
|
||||
[LLMProviderName.OPENAI_COMPATIBLE]: "OpenAI Compatible",
|
||||
|
||||
// fallback
|
||||
[LLMProviderName.CUSTOM]: "Other providers or self-hosted",
|
||||
|
||||
@@ -80,11 +80,6 @@ export const SWR_KEYS = {
|
||||
// ── Users ─────────────────────────────────────────────────────────────────
|
||||
acceptedUsers: "/api/manage/users/accepted/all",
|
||||
invitedUsers: "/api/manage/users/invited",
|
||||
// Curator-accessible listing of all users (and optionally service-account
|
||||
// entries when `?include_api_keys=true`). Used by group create/edit pages so
|
||||
// global curators — who cannot hit the admin-only `/accepted/all` and
|
||||
// `/invited` endpoints — can still load the member picker.
|
||||
groupMemberCandidates: "/api/manage/users?include_api_keys=true",
|
||||
pendingTenantUsers: "/api/tenants/users/pending",
|
||||
userCounts: "/api/manage/users/counts",
|
||||
|
||||
|
||||
@@ -89,6 +89,18 @@ export const KeyWideLayout: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
render: () => (
|
||||
<KeyValueInput
|
||||
keyTitle="Key"
|
||||
valueTitle="Value"
|
||||
items={[{ key: "LOCKED", value: "cannot-edit" }]}
|
||||
onChange={() => {}}
|
||||
disabled
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
export const EmptyLineMode: Story = {
|
||||
render: function EmptyStory() {
|
||||
const [items, setItems] = React.useState<KeyValue[]>([]);
|
||||
|
||||
@@ -68,13 +68,21 @@
|
||||
* ```
|
||||
*/
|
||||
|
||||
import React, { useCallback, useEffect, useMemo, useRef } from "react";
|
||||
import React, {
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useId,
|
||||
useRef,
|
||||
} from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import InputTypeIn from "./InputTypeIn";
|
||||
import { Button, EmptyMessageCard } from "@opal/components";
|
||||
import type { WithoutStyles } from "@opal/types";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { ErrorTextLayout } from "@/layouts/input-layouts";
|
||||
import { FieldContext } from "../form/FieldContext";
|
||||
import { FieldMessage } from "../messages/FieldMessage";
|
||||
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
|
||||
export type KeyValue = { key: string; value: string };
|
||||
@@ -99,50 +107,82 @@ const GRID_COLS = {
|
||||
interface KeyValueInputItemProps {
|
||||
item: KeyValue;
|
||||
onChange: (next: KeyValue) => void;
|
||||
disabled?: boolean;
|
||||
onRemove: () => void;
|
||||
keyPlaceholder?: string;
|
||||
valuePlaceholder?: string;
|
||||
error?: KeyValueError;
|
||||
canRemove: boolean;
|
||||
index: number;
|
||||
fieldId: string;
|
||||
}
|
||||
|
||||
function KeyValueInputItem({
|
||||
item,
|
||||
onChange,
|
||||
disabled,
|
||||
onRemove,
|
||||
keyPlaceholder,
|
||||
valuePlaceholder,
|
||||
error,
|
||||
canRemove,
|
||||
index,
|
||||
fieldId,
|
||||
}: KeyValueInputItemProps) {
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-y-0.5">
|
||||
<InputTypeIn
|
||||
placeholder={keyPlaceholder}
|
||||
placeholder={keyPlaceholder || "Key"}
|
||||
value={item.key}
|
||||
onChange={(e) => onChange({ ...item, key: e.target.value })}
|
||||
aria-label={`${keyPlaceholder || "Key"} ${index + 1}`}
|
||||
aria-invalid={!!error?.key}
|
||||
aria-describedby={
|
||||
error?.key ? `${fieldId}-key-error-${index}` : undefined
|
||||
}
|
||||
variant={disabled ? "disabled" : undefined}
|
||||
showClearButton={false}
|
||||
/>
|
||||
{error?.key && <ErrorTextLayout>{error.key}</ErrorTextLayout>}
|
||||
{error?.key && (
|
||||
<FieldMessage variant="error" className="ml-0.5">
|
||||
<FieldMessage.Content
|
||||
id={`${fieldId}-key-error-${index}`}
|
||||
role="alert"
|
||||
className="ml-0.5"
|
||||
>
|
||||
{error.key}
|
||||
</FieldMessage.Content>
|
||||
</FieldMessage>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-col gap-y-0.5">
|
||||
<InputTypeIn
|
||||
placeholder={valuePlaceholder}
|
||||
placeholder={valuePlaceholder || "Value"}
|
||||
value={item.value}
|
||||
onChange={(e) => onChange({ ...item, value: e.target.value })}
|
||||
aria-label={`${valuePlaceholder || "Value"} ${index + 1}`}
|
||||
aria-invalid={!!error?.value}
|
||||
aria-describedby={
|
||||
error?.value ? `${fieldId}-value-error-${index}` : undefined
|
||||
}
|
||||
variant={disabled ? "disabled" : undefined}
|
||||
showClearButton={false}
|
||||
/>
|
||||
{error?.value && <ErrorTextLayout>{error.value}</ErrorTextLayout>}
|
||||
{error?.value && (
|
||||
<FieldMessage variant="error" className="ml-0.5">
|
||||
<FieldMessage.Content
|
||||
id={`${fieldId}-value-error-${index}`}
|
||||
role="alert"
|
||||
className="ml-0.5"
|
||||
>
|
||||
{error.value}
|
||||
</FieldMessage.Content>
|
||||
</FieldMessage>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
disabled={!canRemove}
|
||||
disabled={disabled || !canRemove}
|
||||
prominence="tertiary"
|
||||
icon={SvgMinusCircle}
|
||||
onClick={onRemove}
|
||||
@@ -158,31 +198,46 @@ export interface KeyValueInputProps
|
||||
> {
|
||||
/** Title for the key column */
|
||||
keyTitle?: string;
|
||||
|
||||
/** Title for the value column */
|
||||
valueTitle?: string;
|
||||
|
||||
/** Placeholder for the key input */
|
||||
keyPlaceholder?: string;
|
||||
|
||||
/** Placeholder for the value input */
|
||||
valuePlaceholder?: string;
|
||||
|
||||
/** Array of key-value pairs */
|
||||
items: KeyValue[];
|
||||
|
||||
/** Callback when items change */
|
||||
onChange: (nextItems: KeyValue[]) => void;
|
||||
|
||||
/** Custom add handler */
|
||||
onAdd?: () => void;
|
||||
/** Custom remove handler */
|
||||
onRemove?: (index: number) => void;
|
||||
/** Disabled state */
|
||||
disabled?: boolean;
|
||||
/** Mode: 'line' allows removing all items, 'fixed-line' requires at least one item */
|
||||
mode?: "line" | "fixed-line";
|
||||
|
||||
/** Layout: 'equal' - both inputs same width, 'key-wide' - key input is wider (60/40 split) */
|
||||
layout?: "equal" | "key-wide";
|
||||
|
||||
/** Callback when validation state changes */
|
||||
onValidationChange?: (isValid: boolean, errors: KeyValueError[]) => void;
|
||||
/** Callback to handle validation errors - integrates with Formik or custom error handling. Called with error message when invalid, null when valid */
|
||||
onValidationError?: (errorMessage: string | null) => void;
|
||||
|
||||
/** Optional custom validator for the key field. Return { isValid, message } */
|
||||
onKeyValidate?: (
|
||||
key: string,
|
||||
index: number,
|
||||
item: KeyValue,
|
||||
items: KeyValue[]
|
||||
) => { isValid: boolean; message?: string };
|
||||
/** Optional custom validator for the value field. Return { isValid, message } */
|
||||
onValueValidate?: (
|
||||
value: string,
|
||||
index: number,
|
||||
item: KeyValue,
|
||||
items: KeyValue[]
|
||||
) => { isValid: boolean; message?: string };
|
||||
/** Whether to validate for duplicate keys */
|
||||
validateDuplicateKeys?: boolean;
|
||||
/** Whether to validate for empty keys */
|
||||
validateEmptyKeys?: boolean;
|
||||
/** Optional name for the field (for accessibility) */
|
||||
name?: string;
|
||||
/** Custom label for the add button (defaults to "Add Line") */
|
||||
addButtonLabel?: string;
|
||||
}
|
||||
@@ -190,16 +245,26 @@ export interface KeyValueInputProps
|
||||
export default function KeyValueInput({
|
||||
keyTitle = "Key",
|
||||
valueTitle = "Value",
|
||||
keyPlaceholder,
|
||||
valuePlaceholder,
|
||||
items = [],
|
||||
onChange,
|
||||
onAdd,
|
||||
onRemove,
|
||||
disabled = false,
|
||||
mode = "line",
|
||||
layout = "equal",
|
||||
onValidationChange,
|
||||
onValidationError,
|
||||
onKeyValidate,
|
||||
onValueValidate,
|
||||
validateDuplicateKeys = true,
|
||||
validateEmptyKeys = true,
|
||||
name,
|
||||
addButtonLabel = "Add Line",
|
||||
...rest
|
||||
}: KeyValueInputProps) {
|
||||
// Try to get field context if used within FormField (safe access)
|
||||
const fieldContext = useContext(FieldContext);
|
||||
|
||||
// Validation logic
|
||||
const errors = useMemo((): KeyValueError[] => {
|
||||
if (!items || items.length === 0) return [];
|
||||
@@ -208,8 +273,12 @@ export default function KeyValueInput({
|
||||
const keyCount = new Map<string, number[]>();
|
||||
|
||||
items.forEach((item, index) => {
|
||||
// Validate empty keys
|
||||
if (item.key.trim() === "" && item.value.trim() !== "") {
|
||||
// Validate empty keys - only if value is filled (user is actively working on this row)
|
||||
if (
|
||||
validateEmptyKeys &&
|
||||
item.key.trim() === "" &&
|
||||
item.value.trim() !== ""
|
||||
) {
|
||||
const error = errorsList[index];
|
||||
if (error) {
|
||||
error.key = "Key cannot be empty";
|
||||
@@ -222,22 +291,56 @@ export default function KeyValueInput({
|
||||
existing.push(index);
|
||||
keyCount.set(item.key, existing);
|
||||
}
|
||||
});
|
||||
|
||||
// Validate duplicate keys
|
||||
keyCount.forEach((indices, key) => {
|
||||
if (indices.length > 1) {
|
||||
indices.forEach((index) => {
|
||||
// Custom key validation
|
||||
if (onKeyValidate) {
|
||||
const result = onKeyValidate(item.key, index, item, items);
|
||||
if (result && result.isValid === false) {
|
||||
const error = errorsList[index];
|
||||
if (error) {
|
||||
error.key = "Duplicate key";
|
||||
error.key = result.message || "Invalid key";
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Custom value validation
|
||||
if (onValueValidate) {
|
||||
const result = onValueValidate(item.value, index, item, items);
|
||||
if (result && result.isValid === false) {
|
||||
const error = errorsList[index];
|
||||
if (error) {
|
||||
error.value = result.message || "Invalid value";
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Validate duplicate keys
|
||||
if (validateDuplicateKeys) {
|
||||
keyCount.forEach((indices, key) => {
|
||||
if (indices.length > 1) {
|
||||
indices.forEach((index) => {
|
||||
const error = errorsList[index];
|
||||
if (error) {
|
||||
error.key = "Duplicate key";
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return errorsList;
|
||||
}, [items]);
|
||||
}, [
|
||||
items,
|
||||
validateDuplicateKeys,
|
||||
validateEmptyKeys,
|
||||
onKeyValidate,
|
||||
onValueValidate,
|
||||
]);
|
||||
|
||||
const isValid = useMemo(() => {
|
||||
return errors.every((error) => !error.key && !error.value);
|
||||
}, [errors]);
|
||||
|
||||
const hasAnyError = useMemo(() => {
|
||||
return errors.some((error) => error.key || error.value);
|
||||
@@ -268,12 +371,21 @@ export default function KeyValueInput({
|
||||
}, [hasAnyError, errors]);
|
||||
|
||||
// Notify parent of validation changes
|
||||
const onValidationChangeRef = useRef(onValidationChange);
|
||||
const onValidationErrorRef = useRef(onValidationError);
|
||||
|
||||
useEffect(() => {
|
||||
onValidationChangeRef.current = onValidationChange;
|
||||
}, [onValidationChange]);
|
||||
|
||||
useEffect(() => {
|
||||
onValidationErrorRef.current = onValidationError;
|
||||
}, [onValidationError]);
|
||||
|
||||
useEffect(() => {
|
||||
onValidationChangeRef.current?.(isValid, errors);
|
||||
}, [isValid, errors]);
|
||||
|
||||
// Notify parent of error state for form library integration
|
||||
useEffect(() => {
|
||||
onValidationErrorRef.current?.(errorMessage);
|
||||
@@ -282,17 +394,25 @@ export default function KeyValueInput({
|
||||
const canRemoveItems = mode === "line" || items.length > 1;
|
||||
|
||||
const handleAdd = useCallback(() => {
|
||||
if (onAdd) {
|
||||
onAdd();
|
||||
return;
|
||||
}
|
||||
onChange([...(items || []), { key: "", value: "" }]);
|
||||
}, [onChange, items]);
|
||||
}, [onAdd, onChange, items]);
|
||||
|
||||
const handleRemove = useCallback(
|
||||
(index: number) => {
|
||||
if (!canRemoveItems && items.length === 1) return;
|
||||
|
||||
if (onRemove) {
|
||||
onRemove(index);
|
||||
return;
|
||||
}
|
||||
const next = (items || []).filter((_, i) => i !== index);
|
||||
onChange(next);
|
||||
},
|
||||
[canRemoveItems, items, onChange]
|
||||
[canRemoveItems, items, onRemove, onChange]
|
||||
);
|
||||
|
||||
const handleItemChange = useCallback(
|
||||
@@ -311,6 +431,8 @@ export default function KeyValueInput({
|
||||
}
|
||||
}, [mode]); // Only run on mode change
|
||||
|
||||
const autoId = useId();
|
||||
const fieldId = fieldContext?.baseId || name || `key-value-input-${autoId}`;
|
||||
const gridCols = GRID_COLS[layout];
|
||||
|
||||
return (
|
||||
@@ -338,24 +460,23 @@ export default function KeyValueInput({
|
||||
key={index}
|
||||
item={item}
|
||||
onChange={(next) => handleItemChange(index, next)}
|
||||
disabled={disabled}
|
||||
onRemove={() => handleRemove(index)}
|
||||
keyPlaceholder={keyPlaceholder}
|
||||
valuePlaceholder={valuePlaceholder}
|
||||
keyPlaceholder={keyTitle}
|
||||
valuePlaceholder={valueTitle}
|
||||
error={errors[index]}
|
||||
canRemove={canRemoveItems}
|
||||
index={index}
|
||||
fieldId={fieldId}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<EmptyMessageCard
|
||||
title="No items added yet."
|
||||
padding="sm"
|
||||
sizePreset="secondary"
|
||||
/>
|
||||
<EmptyMessageCard title="No items added yet." />
|
||||
)}
|
||||
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="secondary"
|
||||
onClick={handleAdd}
|
||||
icon={SvgPlusCircle}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { structureValue } from "@/lib/llmConfig/utils";
|
||||
import {
|
||||
@@ -11,11 +11,25 @@ import {
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgRefreshCw } from "@opal/icons";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import {
|
||||
Accordion,
|
||||
AccordionContent,
|
||||
AccordionItem,
|
||||
AccordionTrigger,
|
||||
} from "@/components/ui/accordion";
|
||||
import {
|
||||
SvgCheck,
|
||||
SvgChevronDown,
|
||||
SvgChevronRight,
|
||||
SvgRefreshCw,
|
||||
} from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { OpenButton } from "@opal/components";
|
||||
import { LLMOption, LLMOptionGroup } from "./interfaces";
|
||||
import ModelListContent from "./ModelListContent";
|
||||
|
||||
export interface LLMPopoverProps {
|
||||
llmManager: LlmManager;
|
||||
@@ -136,6 +150,7 @@ export default function LLMPopover({
|
||||
const isLoadingProviders = llmManager.isLoadingProviders;
|
||||
|
||||
const [open, setOpen] = useState(false);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const { user } = useUser();
|
||||
|
||||
const [localTemperature, setLocalTemperature] = useState(
|
||||
@@ -146,7 +161,9 @@ export default function LLMPopover({
|
||||
setLocalTemperature(llmManager.temperature ?? 0.5);
|
||||
}, [llmManager.temperature]);
|
||||
|
||||
const searchInputRef = useRef<HTMLInputElement>(null);
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
||||
const selectedItemRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleGlobalTemperatureChange = useCallback((value: number[]) => {
|
||||
const value_0 = value[0];
|
||||
@@ -165,28 +182,39 @@ export default function LLMPopover({
|
||||
[llmManager]
|
||||
);
|
||||
|
||||
const isSelected = useCallback(
|
||||
(option: LLMOption) =>
|
||||
option.modelName === llmManager.currentLlm.modelName &&
|
||||
option.provider === llmManager.currentLlm.provider,
|
||||
[llmManager.currentLlm.modelName, llmManager.currentLlm.provider]
|
||||
const llmOptions = useMemo(
|
||||
() => buildLlmOptions(llmProviders, currentModelName),
|
||||
[llmProviders, currentModelName]
|
||||
);
|
||||
|
||||
const handleSelectModel = useCallback(
|
||||
(option: LLMOption) => {
|
||||
llmManager.updateCurrentLlm({
|
||||
modelName: option.modelName,
|
||||
provider: option.provider,
|
||||
name: option.name,
|
||||
} as LlmDescriptor);
|
||||
onSelect?.(
|
||||
structureValue(option.name, option.provider, option.modelName)
|
||||
// Filter options by vision capability (when images are uploaded) and search query
|
||||
const filteredOptions = useMemo(() => {
|
||||
let result = llmOptions;
|
||||
if (requiresImageInput) {
|
||||
result = result.filter((opt) => opt.supportsImageInput);
|
||||
}
|
||||
if (searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase();
|
||||
result = result.filter(
|
||||
(opt) =>
|
||||
opt.displayName.toLowerCase().includes(query) ||
|
||||
opt.modelName.toLowerCase().includes(query) ||
|
||||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
|
||||
);
|
||||
setOpen(false);
|
||||
},
|
||||
[llmManager, onSelect]
|
||||
}
|
||||
return result;
|
||||
}, [llmOptions, searchQuery, requiresImageInput]);
|
||||
|
||||
// Group options by provider using backend-provided display names and ordering
|
||||
// For aggregator providers (bedrock, openrouter, vertex_ai), flatten to "Provider/Vendor" format
|
||||
const groupedOptions = useMemo(
|
||||
() => groupLlmOptions(filteredOptions),
|
||||
[filteredOptions]
|
||||
);
|
||||
|
||||
// Get display name for the model to show in the button
|
||||
// Use currentModelName prop if provided (e.g., for regenerate showing the model used),
|
||||
// otherwise fall back to the globally selected model
|
||||
const currentLlmDisplayName = useMemo(() => {
|
||||
// Only use currentModelName if it's a non-empty string
|
||||
const currentModel =
|
||||
@@ -206,30 +234,122 @@ export default function LLMPopover({
|
||||
return currentModel;
|
||||
}, [llmProviders, currentModelName, llmManager.currentLlm.modelName]);
|
||||
|
||||
const temperatureFooter = user?.preferences?.temperature_override_enabled ? (
|
||||
<>
|
||||
<div className="border-t border-border-02 mx-2" />
|
||||
<div className="flex flex-col w-full py-2 gap-2">
|
||||
<Slider
|
||||
value={[localTemperature]}
|
||||
max={llmManager.maxTemperature}
|
||||
min={0}
|
||||
step={0.01}
|
||||
onValueChange={handleGlobalTemperatureChange}
|
||||
onValueCommit={handleGlobalTemperatureCommit}
|
||||
className="w-full"
|
||||
/>
|
||||
<div className="flex flex-row items-center justify-between">
|
||||
<Text secondaryBody text03>
|
||||
Temperature (creativity)
|
||||
</Text>
|
||||
<Text secondaryBody text03>
|
||||
{localTemperature.toFixed(1)}
|
||||
</Text>
|
||||
</div>
|
||||
// Determine which group the current model belongs to (for auto-expand)
|
||||
const currentGroupKey = useMemo(() => {
|
||||
const currentModel = llmManager.currentLlm.modelName;
|
||||
const currentProvider = llmManager.currentLlm.provider;
|
||||
// Match by both modelName AND provider to handle same model name across providers
|
||||
const option = llmOptions.find(
|
||||
(o) => o.modelName === currentModel && o.provider === currentProvider
|
||||
);
|
||||
if (!option) return "openai";
|
||||
|
||||
const provider = option.provider.toLowerCase();
|
||||
const isAggregator = AGGREGATOR_PROVIDERS.has(provider);
|
||||
|
||||
if (isAggregator && option.vendor) {
|
||||
return `${provider}/${option.vendor.toLowerCase()}`;
|
||||
}
|
||||
return provider;
|
||||
}, [
|
||||
llmOptions,
|
||||
llmManager.currentLlm.modelName,
|
||||
llmManager.currentLlm.provider,
|
||||
]);
|
||||
|
||||
// Track expanded groups - initialize with current model's group
|
||||
const [expandedGroups, setExpandedGroups] = useState<string[]>([
|
||||
currentGroupKey,
|
||||
]);
|
||||
|
||||
// Reset state when popover closes/opens
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
setSearchQuery("");
|
||||
} else {
|
||||
// Reset expanded groups to only show the selected model's group
|
||||
setExpandedGroups([currentGroupKey]);
|
||||
}
|
||||
}, [open, currentGroupKey]);
|
||||
|
||||
// Auto-scroll to selected model when popover opens
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
// Small delay to let accordion content render
|
||||
const timer = setTimeout(() => {
|
||||
selectedItemRef.current?.scrollIntoView({
|
||||
behavior: "instant",
|
||||
block: "center",
|
||||
});
|
||||
}, 50);
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [open]);
|
||||
|
||||
const isSearching = searchQuery.trim().length > 0;
|
||||
|
||||
// Compute final expanded groups
|
||||
const effectiveExpandedGroups = useMemo(() => {
|
||||
if (isSearching) {
|
||||
// Force expand all when searching
|
||||
return groupedOptions.map((g) => g.key);
|
||||
}
|
||||
return expandedGroups;
|
||||
}, [isSearching, groupedOptions, expandedGroups]);
|
||||
|
||||
// Handler for accordion changes
|
||||
const handleAccordionChange = (value: string[]) => {
|
||||
// Only update state when not searching (force-expanding)
|
||||
if (!isSearching) {
|
||||
setExpandedGroups(value);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelectModel = (option: LLMOption) => {
|
||||
llmManager.updateCurrentLlm({
|
||||
modelName: option.modelName,
|
||||
provider: option.provider,
|
||||
name: option.name,
|
||||
} as LlmDescriptor);
|
||||
onSelect?.(structureValue(option.name, option.provider, option.modelName));
|
||||
setOpen(false);
|
||||
};
|
||||
|
||||
const renderModelItem = (option: LLMOption) => {
|
||||
const isSelected =
|
||||
option.modelName === llmManager.currentLlm.modelName &&
|
||||
option.provider === llmManager.currentLlm.provider;
|
||||
|
||||
const capabilities: string[] = [];
|
||||
if (option.supportsReasoning) {
|
||||
capabilities.push("Reasoning");
|
||||
}
|
||||
if (option.supportsImageInput) {
|
||||
capabilities.push("Vision");
|
||||
}
|
||||
const description =
|
||||
capabilities.length > 0 ? capabilities.join(", ") : undefined;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${option.name}-${option.modelName}`}
|
||||
ref={isSelected ? selectedItemRef : undefined}
|
||||
>
|
||||
<LineItem
|
||||
selected={isSelected}
|
||||
description={description}
|
||||
onClick={() => handleSelectModel(option)}
|
||||
rightChildren={
|
||||
isSelected ? (
|
||||
<SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
|
||||
) : null
|
||||
}
|
||||
>
|
||||
{option.displayName}
|
||||
</LineItem>
|
||||
</div>
|
||||
</>
|
||||
) : undefined;
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
@@ -253,16 +373,129 @@ export default function LLMPopover({
|
||||
</div>
|
||||
|
||||
<Popover.Content side="top" align="end" width="xl">
|
||||
<ModelListContent
|
||||
llmProviders={llmProviders}
|
||||
currentModelName={currentModelName}
|
||||
requiresImageInput={requiresImageInput}
|
||||
isLoading={isLoadingProviders}
|
||||
onSelect={handleSelectModel}
|
||||
isSelected={isSelected}
|
||||
scrollContainerRef={scrollContainerRef}
|
||||
footer={temperatureFooter}
|
||||
/>
|
||||
<Section gap={0.5}>
|
||||
{/* Search Input */}
|
||||
<InputTypeIn
|
||||
ref={searchInputRef}
|
||||
leftSearchIcon
|
||||
variant="internal"
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
placeholder="Search models..."
|
||||
/>
|
||||
|
||||
{/* Model List with Vendor Groups */}
|
||||
<PopoverMenu scrollContainerRef={scrollContainerRef}>
|
||||
{isLoadingProviders
|
||||
? [
|
||||
<div key="loading" className="flex items-center gap-2 py-3">
|
||||
<SimpleLoader />
|
||||
<Text secondaryBody text03>
|
||||
Loading models...
|
||||
</Text>
|
||||
</div>,
|
||||
]
|
||||
: groupedOptions.length === 0
|
||||
? [
|
||||
<div key="empty" className="py-3">
|
||||
<Text secondaryBody text03>
|
||||
No models found
|
||||
</Text>
|
||||
</div>,
|
||||
]
|
||||
: groupedOptions.length === 1
|
||||
? // Single provider - show models directly without accordion
|
||||
[
|
||||
<div
|
||||
key="single-provider"
|
||||
className="flex flex-col gap-1"
|
||||
>
|
||||
{groupedOptions[0]!.options.map(renderModelItem)}
|
||||
</div>,
|
||||
]
|
||||
: // Multiple providers - show accordion with groups
|
||||
[
|
||||
<Accordion
|
||||
key="accordion"
|
||||
type="multiple"
|
||||
value={effectiveExpandedGroups}
|
||||
onValueChange={handleAccordionChange}
|
||||
className="w-full flex flex-col"
|
||||
>
|
||||
{groupedOptions.map((group) => {
|
||||
const isExpanded = effectiveExpandedGroups.includes(
|
||||
group.key
|
||||
);
|
||||
return (
|
||||
<AccordionItem
|
||||
key={group.key}
|
||||
value={group.key}
|
||||
className="border-none pt-1"
|
||||
>
|
||||
{/* Group Header */}
|
||||
<AccordionTrigger className="flex items-center rounded-08 hover:no-underline hover:bg-background-tint-02 group [&>svg]:hidden w-full py-1">
|
||||
<div className="flex items-center gap-1 shrink-0">
|
||||
<div className="flex items-center justify-center size-5 shrink-0">
|
||||
<group.Icon size={16} />
|
||||
</div>
|
||||
<Text
|
||||
secondaryBody
|
||||
text03
|
||||
nowrap
|
||||
className="px-0.5"
|
||||
>
|
||||
{group.displayName}
|
||||
</Text>
|
||||
</div>
|
||||
<div className="flex-1" />
|
||||
<div className="flex items-center justify-center size-6 shrink-0">
|
||||
{isExpanded ? (
|
||||
<SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
|
||||
) : (
|
||||
<SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
|
||||
)}
|
||||
</div>
|
||||
</AccordionTrigger>
|
||||
|
||||
{/* Model Items - full width highlight */}
|
||||
<AccordionContent className="pb-0 pt-0">
|
||||
<div className="flex flex-col gap-1">
|
||||
{group.options.map(renderModelItem)}
|
||||
</div>
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
);
|
||||
})}
|
||||
</Accordion>,
|
||||
]}
|
||||
</PopoverMenu>
|
||||
|
||||
{/* Global Temperature Slider (shown if enabled in user prefs) */}
|
||||
{user?.preferences?.temperature_override_enabled && (
|
||||
<>
|
||||
<div className="border-t border-border-02 mx-2" />
|
||||
<div className="flex flex-col w-full py-2 gap-2">
|
||||
<Slider
|
||||
value={[localTemperature]}
|
||||
max={llmManager.maxTemperature}
|
||||
min={0}
|
||||
step={0.01}
|
||||
onValueChange={handleGlobalTemperatureChange}
|
||||
onValueCommit={handleGlobalTemperatureCommit}
|
||||
className="w-full"
|
||||
/>
|
||||
<div className="flex flex-row items-center justify-between">
|
||||
<Text secondaryBody text03>
|
||||
Temperature (creativity)
|
||||
</Text>
|
||||
<Text secondaryBody text03>
|
||||
{localTemperature.toFixed(1)}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</Section>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
);
|
||||
|
||||
@@ -1,200 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo, useRef, useEffect } from "react";
|
||||
import { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import { Text } from "@opal/components";
|
||||
import { SvgCheck, SvgChevronDown, SvgChevronRight } from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { LLMOption } from "./interfaces";
|
||||
import { buildLlmOptions, groupLlmOptions } from "./LLMPopover";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
Collapsible,
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from "@/refresh-components/Collapsible";
|
||||
|
||||
export interface ModelListContentProps {
|
||||
llmProviders: LLMProviderDescriptor[] | undefined;
|
||||
currentModelName?: string;
|
||||
requiresImageInput?: boolean;
|
||||
onSelect: (option: LLMOption) => void;
|
||||
isSelected: (option: LLMOption) => boolean;
|
||||
isDisabled?: (option: LLMOption) => boolean;
|
||||
scrollContainerRef?: React.RefObject<HTMLDivElement | null>;
|
||||
isLoading?: boolean;
|
||||
footer?: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function ModelListContent({
|
||||
llmProviders,
|
||||
currentModelName,
|
||||
requiresImageInput,
|
||||
onSelect,
|
||||
isSelected,
|
||||
isDisabled,
|
||||
scrollContainerRef: externalScrollRef,
|
||||
isLoading,
|
||||
footer,
|
||||
}: ModelListContentProps) {
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const internalScrollRef = useRef<HTMLDivElement>(null);
|
||||
const scrollContainerRef = externalScrollRef ?? internalScrollRef;
|
||||
|
||||
const llmOptions = useMemo(
|
||||
() => buildLlmOptions(llmProviders, currentModelName),
|
||||
[llmProviders, currentModelName]
|
||||
);
|
||||
|
||||
const filteredOptions = useMemo(() => {
|
||||
let result = llmOptions;
|
||||
if (requiresImageInput) {
|
||||
result = result.filter((opt) => opt.supportsImageInput);
|
||||
}
|
||||
if (searchQuery.trim()) {
|
||||
const query = searchQuery.toLowerCase();
|
||||
result = result.filter(
|
||||
(opt) =>
|
||||
opt.displayName.toLowerCase().includes(query) ||
|
||||
opt.modelName.toLowerCase().includes(query) ||
|
||||
(opt.vendor && opt.vendor.toLowerCase().includes(query))
|
||||
);
|
||||
}
|
||||
return result;
|
||||
}, [llmOptions, searchQuery, requiresImageInput]);
|
||||
|
||||
const groupedOptions = useMemo(
|
||||
() => groupLlmOptions(filteredOptions),
|
||||
[filteredOptions]
|
||||
);
|
||||
|
||||
// Find which group contains a currently-selected model (for auto-expand)
|
||||
const defaultGroupKey = useMemo(() => {
|
||||
for (const group of groupedOptions) {
|
||||
if (group.options.some((opt) => isSelected(opt))) {
|
||||
return group.key;
|
||||
}
|
||||
}
|
||||
return groupedOptions[0]?.key ?? "";
|
||||
}, [groupedOptions, isSelected]);
|
||||
|
||||
const [expandedGroups, setExpandedGroups] = useState<Set<string>>(
|
||||
new Set([defaultGroupKey])
|
||||
);
|
||||
|
||||
// Reset expanded groups when default changes (e.g. popover re-opens)
|
||||
useEffect(() => {
|
||||
setExpandedGroups(new Set([defaultGroupKey]));
|
||||
}, [defaultGroupKey]);
|
||||
|
||||
const isSearching = searchQuery.trim().length > 0;
|
||||
|
||||
const toggleGroup = (key: string) => {
|
||||
if (isSearching) return;
|
||||
setExpandedGroups((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(key)) next.delete(key);
|
||||
else next.add(key);
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const isGroupOpen = (key: string) => isSearching || expandedGroups.has(key);
|
||||
|
||||
const renderModelItem = (option: LLMOption) => {
|
||||
const selected = isSelected(option);
|
||||
const disabled = isDisabled?.(option) ?? false;
|
||||
|
||||
const capabilities: string[] = [];
|
||||
if (option.supportsReasoning) capabilities.push("Reasoning");
|
||||
if (option.supportsImageInput) capabilities.push("Vision");
|
||||
const description =
|
||||
capabilities.length > 0 ? capabilities.join(", ") : undefined;
|
||||
|
||||
return (
|
||||
<LineItem
|
||||
key={`${option.provider}:${option.modelName}`}
|
||||
selected={selected}
|
||||
disabled={disabled}
|
||||
description={description}
|
||||
onClick={() => onSelect(option)}
|
||||
rightChildren={
|
||||
selected ? (
|
||||
<SvgCheck className="h-4 w-4 stroke-action-link-05 shrink-0" />
|
||||
) : null
|
||||
}
|
||||
>
|
||||
{option.displayName}
|
||||
</LineItem>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Section gap={0.5}>
|
||||
<InputTypeIn
|
||||
leftSearchIcon
|
||||
variant="internal"
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
placeholder="Search models..."
|
||||
/>
|
||||
|
||||
<PopoverMenu scrollContainerRef={scrollContainerRef}>
|
||||
{isLoading
|
||||
? [
|
||||
<Text key="loading" font="secondary-body" color="text-03">
|
||||
Loading models...
|
||||
</Text>,
|
||||
]
|
||||
: groupedOptions.length === 0
|
||||
? [
|
||||
<Text key="empty" font="secondary-body" color="text-03">
|
||||
No models found
|
||||
</Text>,
|
||||
]
|
||||
: groupedOptions.length === 1
|
||||
? [
|
||||
<Section key="single-provider" gap={0.25}>
|
||||
{groupedOptions[0]!.options.map(renderModelItem)}
|
||||
</Section>,
|
||||
]
|
||||
: groupedOptions.map((group) => {
|
||||
const open = isGroupOpen(group.key);
|
||||
return (
|
||||
<Collapsible
|
||||
key={group.key}
|
||||
open={open}
|
||||
onOpenChange={() => toggleGroup(group.key)}
|
||||
>
|
||||
<CollapsibleTrigger asChild>
|
||||
<LineItem
|
||||
muted
|
||||
icon={group.Icon}
|
||||
rightChildren={
|
||||
open ? (
|
||||
<SvgChevronDown className="h-4 w-4 stroke-text-04 shrink-0" />
|
||||
) : (
|
||||
<SvgChevronRight className="h-4 w-4 stroke-text-04 shrink-0" />
|
||||
)
|
||||
}
|
||||
>
|
||||
{group.displayName}
|
||||
</LineItem>
|
||||
</CollapsibleTrigger>
|
||||
|
||||
<CollapsibleContent>
|
||||
<Section gap={0.25}>
|
||||
{group.options.map(renderModelItem)}
|
||||
</Section>
|
||||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
);
|
||||
})}
|
||||
</PopoverMenu>
|
||||
|
||||
{footer}
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
@@ -1,230 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo, useRef } from "react";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import { Button, SelectButton, OpenButton } from "@opal/components";
|
||||
import { SvgPlusCircle, SvgX } from "@opal/icons";
|
||||
import { LLMOption } from "@/refresh-components/popovers/interfaces";
|
||||
import ModelListContent from "@/refresh-components/popovers/ModelListContent";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
|
||||
export const MAX_MODELS = 3;
|
||||
|
||||
export interface SelectedModel {
|
||||
name: string;
|
||||
provider: string;
|
||||
modelName: string;
|
||||
displayName: string;
|
||||
}
|
||||
|
||||
export interface ModelSelectorProps {
|
||||
llmManager: LlmManager;
|
||||
selectedModels: SelectedModel[];
|
||||
onAdd: (model: SelectedModel) => void;
|
||||
onRemove: (index: number) => void;
|
||||
onReplace: (index: number, model: SelectedModel) => void;
|
||||
}
|
||||
|
||||
function modelKey(provider: string, modelName: string): string {
|
||||
return `${provider}:${modelName}`;
|
||||
}
|
||||
|
||||
export default function ModelSelector({
|
||||
llmManager,
|
||||
selectedModels,
|
||||
onAdd,
|
||||
onRemove,
|
||||
onReplace,
|
||||
}: ModelSelectorProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
// null = add mode (via + button), number = replace mode (via pill click)
|
||||
const [replacingIndex, setReplacingIndex] = useState<number | null>(null);
|
||||
// Virtual anchor ref — points to the clicked pill so the popover positions above it
|
||||
const anchorRef = useRef<HTMLElement | null>(null);
|
||||
|
||||
const isMultiModel = selectedModels.length > 1;
|
||||
const atMax = selectedModels.length >= MAX_MODELS;
|
||||
|
||||
const selectedKeys = useMemo(
|
||||
() => new Set(selectedModels.map((m) => modelKey(m.provider, m.modelName))),
|
||||
[selectedModels]
|
||||
);
|
||||
|
||||
const otherSelectedKeys = useMemo(() => {
|
||||
if (replacingIndex === null) return new Set<string>();
|
||||
return new Set(
|
||||
selectedModels
|
||||
.filter((_, i) => i !== replacingIndex)
|
||||
.map((m) => modelKey(m.provider, m.modelName))
|
||||
);
|
||||
}, [selectedModels, replacingIndex]);
|
||||
|
||||
const replacingKey =
|
||||
replacingIndex !== null
|
||||
? (() => {
|
||||
const m = selectedModels[replacingIndex];
|
||||
return m ? modelKey(m.provider, m.modelName) : null;
|
||||
})()
|
||||
: null;
|
||||
|
||||
const isSelected = (option: LLMOption) => {
|
||||
const key = modelKey(option.provider, option.modelName);
|
||||
if (replacingIndex !== null) return key === replacingKey;
|
||||
return selectedKeys.has(key);
|
||||
};
|
||||
|
||||
const isDisabled = (option: LLMOption) => {
|
||||
const key = modelKey(option.provider, option.modelName);
|
||||
if (replacingIndex !== null) return otherSelectedKeys.has(key);
|
||||
return !selectedKeys.has(key) && atMax;
|
||||
};
|
||||
|
||||
const handleSelect = (option: LLMOption) => {
|
||||
const model: SelectedModel = {
|
||||
name: option.name,
|
||||
provider: option.provider,
|
||||
modelName: option.modelName,
|
||||
displayName: option.displayName,
|
||||
};
|
||||
|
||||
if (replacingIndex !== null) {
|
||||
onReplace(replacingIndex, model);
|
||||
setOpen(false);
|
||||
setReplacingIndex(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const key = modelKey(option.provider, option.modelName);
|
||||
const existingIndex = selectedModels.findIndex(
|
||||
(m) => modelKey(m.provider, m.modelName) === key
|
||||
);
|
||||
if (existingIndex >= 0) {
|
||||
onRemove(existingIndex);
|
||||
} else if (!atMax) {
|
||||
onAdd(model);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOpenChange = (nextOpen: boolean) => {
|
||||
setOpen(nextOpen);
|
||||
if (!nextOpen) setReplacingIndex(null);
|
||||
};
|
||||
|
||||
const handlePillClick = (index: number, element: HTMLElement) => {
|
||||
anchorRef.current = element;
|
||||
setReplacingIndex(index);
|
||||
setOpen(true);
|
||||
};
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={handleOpenChange}>
|
||||
<div className="flex items-center justify-end gap-1 p-1">
|
||||
{!atMax && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgPlusCircle}
|
||||
size="sm"
|
||||
tooltip="Add Model"
|
||||
onClick={(e: React.MouseEvent) => {
|
||||
anchorRef.current = e.currentTarget as HTMLElement;
|
||||
setReplacingIndex(null);
|
||||
setOpen(true);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Popover.Anchor
|
||||
virtualRef={anchorRef as React.RefObject<HTMLElement>}
|
||||
/>
|
||||
{selectedModels.length > 0 && (
|
||||
<>
|
||||
{!atMax && (
|
||||
<Separator
|
||||
orientation="vertical"
|
||||
paddingXRem={0.5}
|
||||
paddingYRem={0.5}
|
||||
/>
|
||||
)}
|
||||
<div className="flex items-center">
|
||||
{selectedModels.map((model, index) => {
|
||||
const ProviderIcon = getProviderIcon(
|
||||
model.provider,
|
||||
model.modelName
|
||||
);
|
||||
|
||||
if (!isMultiModel) {
|
||||
return (
|
||||
<OpenButton
|
||||
key={modelKey(model.provider, model.modelName)}
|
||||
icon={ProviderIcon}
|
||||
onClick={(e: React.MouseEvent) =>
|
||||
handlePillClick(index, e.currentTarget as HTMLElement)
|
||||
}
|
||||
>
|
||||
{model.displayName}
|
||||
</OpenButton>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
key={modelKey(model.provider, model.modelName)}
|
||||
className="flex items-center"
|
||||
>
|
||||
{index > 0 && (
|
||||
<Separator
|
||||
orientation="vertical"
|
||||
paddingXRem={0.5}
|
||||
className="h-5"
|
||||
/>
|
||||
)}
|
||||
<SelectButton
|
||||
icon={ProviderIcon}
|
||||
rightIcon={SvgX}
|
||||
state="empty"
|
||||
variant="select-tinted"
|
||||
interaction="hover"
|
||||
size="lg"
|
||||
onClick={(e: React.MouseEvent) => {
|
||||
const target = e.target as HTMLElement;
|
||||
const btn = e.currentTarget as HTMLElement;
|
||||
const icons = btn.querySelectorAll(
|
||||
".interactive-foreground-icon"
|
||||
);
|
||||
const lastIcon = icons[icons.length - 1];
|
||||
if (lastIcon && lastIcon.contains(target)) {
|
||||
onRemove(index);
|
||||
} else {
|
||||
handlePillClick(index, btn);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{model.displayName}
|
||||
</SelectButton>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Popover.Content
|
||||
side="top"
|
||||
align="start"
|
||||
width="lg"
|
||||
avoidCollisions={false}
|
||||
>
|
||||
<ModelListContent
|
||||
llmProviders={llmManager.llmProviders}
|
||||
isLoading={llmManager.isLoadingProviders}
|
||||
onSelect={handleSelect}
|
||||
isSelected={isSelected}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
@@ -232,6 +232,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
onboardingDismissed,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptors,
|
||||
isLoadingOnboarding,
|
||||
finishOnboarding,
|
||||
hideOnboarding,
|
||||
@@ -811,6 +812,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
handleFinishOnboarding={finishOnboarding}
|
||||
state={onboardingState}
|
||||
actions={onboardingActions}
|
||||
llmDescriptors={llmDescriptors}
|
||||
/>
|
||||
)}
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useMemo, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR from "swr";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers } from "@opal/icons";
|
||||
@@ -13,14 +14,17 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import useGroupMemberCandidates from "./useGroupMemberCandidates";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
import {
|
||||
createGroup,
|
||||
updateAgentGroupSharing,
|
||||
updateDocSetGroupSharing,
|
||||
saveTokenLimits,
|
||||
} from "./svc";
|
||||
import { memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import { apiKeyToMemberRow, memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
|
||||
import TokenLimitSection from "./TokenLimitSection";
|
||||
import type { TokenLimit } from "./TokenLimitSection";
|
||||
@@ -38,7 +42,22 @@ function CreateGroupPage() {
|
||||
{ tokenBudget: null, periodHours: null },
|
||||
]);
|
||||
|
||||
const { rows: allRows, isLoading, error } = useGroupMemberCandidates();
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>(SWR_KEYS.adminApiKeys, errorHandlingFetcher);
|
||||
|
||||
const isLoading = usersLoading || apiKeysLoading;
|
||||
const error = usersError ?? apiKeysError;
|
||||
|
||||
const allRows: MemberRow[] = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
|
||||
async function handleCreate() {
|
||||
const trimmed = groupName.trim();
|
||||
@@ -115,11 +134,11 @@ function CreateGroupPage() {
|
||||
{/* Members table */}
|
||||
{isLoading && <SimpleLoader />}
|
||||
|
||||
{error ? (
|
||||
{error && (
|
||||
<Text as="p" secondaryBody text03>
|
||||
Failed to load users.
|
||||
</Text>
|
||||
) : null}
|
||||
)}
|
||||
|
||||
{!isLoading && !error && (
|
||||
<Section
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR, { useSWRConfig } from "swr";
|
||||
import useGroupMemberCandidates from "./useGroupMemberCandidates";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers, SvgTrash, SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
@@ -20,9 +19,20 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import type { MemberRow, TokenRateLimitDisplay } from "./interfaces";
|
||||
import { baseColumns, memberTableColumns, tc, PAGE_SIZE } from "./shared";
|
||||
import type {
|
||||
ApiKeyDescriptor,
|
||||
MemberRow,
|
||||
TokenRateLimitDisplay,
|
||||
} from "./interfaces";
|
||||
import {
|
||||
apiKeyToMemberRow,
|
||||
baseColumns,
|
||||
memberTableColumns,
|
||||
tc,
|
||||
PAGE_SIZE,
|
||||
} from "./shared";
|
||||
import {
|
||||
renameGroup,
|
||||
updateGroup,
|
||||
@@ -94,15 +104,18 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
const initialAgentIdsRef = useRef<number[]>([]);
|
||||
const initialDocSetIdsRef = useRef<number[]>([]);
|
||||
|
||||
// Users + service accounts (curator-accessible — see hook docs).
|
||||
const {
|
||||
rows: allRows,
|
||||
isLoading: candidatesLoading,
|
||||
error: candidatesError,
|
||||
} = useGroupMemberCandidates();
|
||||
// Users and API keys
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
const isLoading = groupLoading || candidatesLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? candidatesError;
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>(SWR_KEYS.adminApiKeys, errorHandlingFetcher);
|
||||
|
||||
const isLoading =
|
||||
groupLoading || usersLoading || apiKeysLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? usersError ?? apiKeysError;
|
||||
|
||||
// Pre-populate form when group data loads
|
||||
useEffect(() => {
|
||||
@@ -132,6 +145,12 @@ function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
}
|
||||
}, [tokenRateLimits]);
|
||||
|
||||
const allRows = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
|
||||
const memberRows = useMemo(() => {
|
||||
const selected = new Set(selectedUserIds);
|
||||
return allRows.filter((r) => selected.has(r.id ?? r.email));
|
||||
|
||||
@@ -1,147 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { SWR_KEYS } from "@/lib/swr-keys";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { AccountType, UserStatus, type UserRole } from "@/lib/types";
|
||||
import type {
|
||||
UserGroupInfo,
|
||||
UserRow,
|
||||
} from "@/refresh-pages/admin/UsersPage/interfaces";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
|
||||
// Backend response shape for `/api/manage/users?include_api_keys=true`. The
|
||||
// existing `AllUsersResponse` in `lib/types.ts` types `accepted` as `User[]`,
|
||||
// which is missing fields the table needs (`personal_name`, `account_type`,
|
||||
// `groups`, etc.), so we declare an accurate local type here.
|
||||
interface FullUserSnapshot {
|
||||
id: string;
|
||||
email: string;
|
||||
role: UserRole;
|
||||
account_type: AccountType;
|
||||
is_active: boolean;
|
||||
password_configured: boolean;
|
||||
personal_name: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
groups: UserGroupInfo[];
|
||||
is_scim_synced: boolean;
|
||||
}
|
||||
|
||||
interface ManageUsersResponse {
|
||||
accepted: FullUserSnapshot[];
|
||||
invited: { email: string }[];
|
||||
slack_users: FullUserSnapshot[];
|
||||
accepted_pages: number;
|
||||
invited_pages: number;
|
||||
slack_users_pages: number;
|
||||
}
|
||||
|
||||
function snapshotToMemberRow(snapshot: FullUserSnapshot): MemberRow {
|
||||
return {
|
||||
id: snapshot.id,
|
||||
email: snapshot.email,
|
||||
role: snapshot.role,
|
||||
status: snapshot.is_active ? UserStatus.ACTIVE : UserStatus.INACTIVE,
|
||||
is_active: snapshot.is_active,
|
||||
is_scim_synced: snapshot.is_scim_synced,
|
||||
personal_name: snapshot.personal_name,
|
||||
created_at: snapshot.created_at,
|
||||
updated_at: snapshot.updated_at,
|
||||
groups: snapshot.groups,
|
||||
};
|
||||
}
|
||||
|
||||
function serviceAccountToMemberRow(
|
||||
snapshot: FullUserSnapshot,
|
||||
apiKey: ApiKeyDescriptor | undefined
|
||||
): MemberRow {
|
||||
return {
|
||||
id: snapshot.id,
|
||||
email: "Service Account",
|
||||
role: apiKey?.api_key_role ?? snapshot.role,
|
||||
status: UserStatus.ACTIVE,
|
||||
is_active: true,
|
||||
is_scim_synced: false,
|
||||
personal_name:
|
||||
apiKey?.api_key_name ?? snapshot.personal_name ?? "Unnamed Key",
|
||||
created_at: null,
|
||||
updated_at: null,
|
||||
groups: [],
|
||||
api_key_display: apiKey?.api_key_display,
|
||||
};
|
||||
}
|
||||
|
||||
interface UseGroupMemberCandidatesResult {
|
||||
/** Active users + service-account rows, in the order the table expects. */
|
||||
rows: MemberRow[];
|
||||
/** Subset of `rows` representing real (non-service-account) users. */
|
||||
userRows: MemberRow[];
|
||||
isLoading: boolean;
|
||||
error: unknown;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the candidate list for the group create/edit member pickers.
|
||||
*
|
||||
* Hits `/api/manage/users?include_api_keys=true`, which is gated by
|
||||
* `current_curator_or_admin_user` on the backend, so this works for both
|
||||
* admins and global curators (the admin-only `/accepted/all` and `/invited`
|
||||
* endpoints used to be called here, which 403'd for global curators and broke
|
||||
* the Edit Group page entirely).
|
||||
*
|
||||
* For admins, we additionally fetch `/admin/api-key` to enrich service-account
|
||||
* rows with the masked api-key display string. That call is admin-only and is
|
||||
* skipped for curators; its failure is non-fatal.
|
||||
*/
|
||||
export default function useGroupMemberCandidates(): UseGroupMemberCandidatesResult {
|
||||
const { isAdmin } = useUser();
|
||||
|
||||
const {
|
||||
data: usersData,
|
||||
isLoading: usersLoading,
|
||||
error: usersError,
|
||||
} = useSWR<ManageUsersResponse>(
|
||||
SWR_KEYS.groupMemberCandidates,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const { data: apiKeys, isLoading: apiKeysLoading } = useSWR<
|
||||
ApiKeyDescriptor[]
|
||||
>(isAdmin ? SWR_KEYS.adminApiKeys : null, errorHandlingFetcher);
|
||||
|
||||
const apiKeysByUserId = useMemo(() => {
|
||||
const map = new Map<string, ApiKeyDescriptor>();
|
||||
for (const key of apiKeys ?? []) map.set(key.user_id, key);
|
||||
return map;
|
||||
}, [apiKeys]);
|
||||
|
||||
const { rows, userRows } = useMemo(() => {
|
||||
const accepted = usersData?.accepted ?? [];
|
||||
const userRowsLocal: MemberRow[] = [];
|
||||
const serviceAccountRows: MemberRow[] = [];
|
||||
for (const snapshot of accepted) {
|
||||
if (!snapshot.is_active) continue;
|
||||
if (snapshot.account_type === AccountType.SERVICE_ACCOUNT) {
|
||||
serviceAccountRows.push(
|
||||
serviceAccountToMemberRow(snapshot, apiKeysByUserId.get(snapshot.id))
|
||||
);
|
||||
} else {
|
||||
userRowsLocal.push(snapshotToMemberRow(snapshot));
|
||||
}
|
||||
}
|
||||
return {
|
||||
rows: [...userRowsLocal, ...serviceAccountRows],
|
||||
userRows: userRowsLocal,
|
||||
};
|
||||
}, [usersData, apiKeysByUserId]);
|
||||
|
||||
return {
|
||||
rows,
|
||||
userRows,
|
||||
isLoading: usersLoading || (isAdmin && apiKeysLoading),
|
||||
error: usersError,
|
||||
};
|
||||
}
|
||||
@@ -31,7 +31,6 @@ import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationMo
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
@@ -44,10 +43,9 @@ import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
@@ -59,60 +57,93 @@ const route = ADMIN_ROUTES.LLM_MODELS;
|
||||
// Client-side ordering for the "Add Provider" cards. The backend may return
|
||||
// wellKnownLLMProviders in an arbitrary order, so we sort explicitly here.
|
||||
const PROVIDER_DISPLAY_ORDER: string[] = [
|
||||
LLMProviderName.OPENAI,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
LLMProviderName.BEDROCK,
|
||||
LLMProviderName.AZURE,
|
||||
LLMProviderName.LITELLM,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
LLMProviderName.OPENROUTER,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
LLMProviderName.BIFROST,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
"openai",
|
||||
"anthropic",
|
||||
"vertex_ai",
|
||||
"bedrock",
|
||||
"azure",
|
||||
"litellm_proxy",
|
||||
"ollama_chat",
|
||||
"openrouter",
|
||||
"lm_studio",
|
||||
"bifrost",
|
||||
];
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
string,
|
||||
(
|
||||
shouldMarkAsDefault: boolean,
|
||||
open: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode
|
||||
> = {
|
||||
openai: (d, onOpenChange) => (
|
||||
<OpenAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
anthropic: (d, onOpenChange) => (
|
||||
<AnthropicModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
ollama_chat: (d, onOpenChange) => (
|
||||
<OllamaModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
azure: (d, onOpenChange) => (
|
||||
<AzureModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
bedrock: (d, onOpenChange) => (
|
||||
<BedrockModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
vertex_ai: (d, onOpenChange) => (
|
||||
<VertexAIModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
openrouter: (d, onOpenChange) => (
|
||||
<OpenRouterModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
lm_studio: (d, onOpenChange) => (
|
||||
<LMStudioModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
litellm_proxy: (d, onOpenChange) => (
|
||||
<LiteLLMProxyModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
bifrost: (d, onOpenChange) => (
|
||||
<BifrostModal shouldMarkAsDefault={d} onOpenChange={onOpenChange} />
|
||||
),
|
||||
openai_compatible: (d, onOpenChange) => (
|
||||
<OpenAICompatibleModal
|
||||
openai: (d, open, onOpenChange) => (
|
||||
<OpenAIModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
anthropic: (d, open, onOpenChange) => (
|
||||
<AnthropicModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
ollama_chat: (d, open, onOpenChange) => (
|
||||
<OllamaModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
azure: (d, open, onOpenChange) => (
|
||||
<AzureModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
bedrock: (d, open, onOpenChange) => (
|
||||
<BedrockModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
vertex_ai: (d, open, onOpenChange) => (
|
||||
<VertexAIModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
openrouter: (d, open, onOpenChange) => (
|
||||
<OpenRouterModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
lm_studio: (d, open, onOpenChange) => (
|
||||
<LMStudioForm
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
litellm_proxy: (d, open, onOpenChange) => (
|
||||
<LiteLLMProxyModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
bifrost: (d, open, onOpenChange) => (
|
||||
<BifrostModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
@@ -179,10 +210,7 @@ function ExistingProviderCard({
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
<Hoverable.Root
|
||||
group="ExistingProviderCard"
|
||||
interaction={deleteModal.isOpen ? "hover" : "rest"}
|
||||
>
|
||||
<Hoverable.Root group="ExistingProviderCard">
|
||||
<SelectCard
|
||||
state="filled"
|
||||
padding="sm"
|
||||
@@ -224,8 +252,12 @@ function ExistingProviderCard({
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
{isOpen &&
|
||||
getModalForExistingProvider(provider, setIsOpen, defaultModelName)}
|
||||
{getModalForExistingProvider(
|
||||
provider,
|
||||
isOpen,
|
||||
setIsOpen,
|
||||
defaultModelName
|
||||
)}
|
||||
</SelectCard>
|
||||
</Hoverable.Root>
|
||||
</>
|
||||
@@ -241,6 +273,7 @@ interface NewProviderCardProps {
|
||||
isFirstProvider: boolean;
|
||||
formFn: (
|
||||
shouldMarkAsDefault: boolean,
|
||||
open: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode;
|
||||
}
|
||||
@@ -278,7 +311,7 @@ function NewProviderCard({
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{isOpen && formFn(isFirstProvider, setIsOpen)}
|
||||
{formFn(isFirstProvider, isOpen, setIsOpen)}
|
||||
</SelectCard>
|
||||
);
|
||||
}
|
||||
@@ -322,12 +355,11 @@ function NewCustomProviderCard({
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{isOpen && (
|
||||
<CustomModal
|
||||
shouldMarkAsDefault={isFirstProvider}
|
||||
onOpenChange={setIsOpen}
|
||||
/>
|
||||
)}
|
||||
<CustomModal
|
||||
shouldMarkAsDefault={isFirstProvider}
|
||||
open={isOpen}
|
||||
onOpenChange={setIsOpen}
|
||||
/>
|
||||
</SelectCard>
|
||||
);
|
||||
}
|
||||
@@ -336,7 +368,7 @@ function NewCustomProviderCard({
|
||||
// LLMConfigurationPage — main page component
|
||||
// ============================================================================
|
||||
|
||||
export default function LLMProviderConfigurationPage() {
|
||||
export default function LLMConfigurationPage() {
|
||||
const { mutate } = useSWRConfig();
|
||||
const { llmProviders: existingLlmProviders, defaultText } =
|
||||
useAdminLLMProviders();
|
||||
@@ -1,99 +1,177 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import { Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const ANTHROPIC_PROVIDER_NAME = "anthropic";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
|
||||
|
||||
export default function AnthropicModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.ANTHROPIC,
|
||||
existingLlmProvider
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
});
|
||||
const initialValues = isOnboarding
|
||||
? {
|
||||
...buildOnboardingInitialValues(),
|
||||
name: ANTHROPIC_PROVIDER_NAME,
|
||||
provider: ANTHROPIC_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.ANTHROPIC}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.ANTHROPIC,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: ANTHROPIC_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: ANTHROPIC_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<APIKeyField providerName="Anthropic" />
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<APIKeyField providerName="Anthropic" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. claude-sonnet-4-5" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,35 +1,45 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Formik } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModelSelectionField,
|
||||
ModalWrapper,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
ModelsField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import {
|
||||
isValidAzureTargetUri,
|
||||
parseAzureTargetUri,
|
||||
} from "@/lib/azureTargetUri";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const AZURE_PROVIDER_NAME = "azure";
|
||||
|
||||
interface AzureModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
@@ -39,33 +49,6 @@ interface AzureModalValues extends BaseLLMFormValues {
|
||||
deployment_name?: string;
|
||||
}
|
||||
|
||||
function AzureModelSelection() {
|
||||
const formikProps = useFormikContext<AzureModalValues>();
|
||||
return (
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onAddModel={(modelName) => {
|
||||
const current = formikProps.values.model_configurations;
|
||||
if (current.some((m) => m.name === modelName)) return;
|
||||
const updated = [
|
||||
...current,
|
||||
{
|
||||
name: modelName,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
];
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
if (!formikProps.values.test_model_name) {
|
||||
formikProps.setFieldValue("test_model_name", modelName);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function buildTargetUri(existingLlmProvider?: LLMProviderView): string {
|
||||
if (!existingLlmProvider?.api_base || !existingLlmProvider?.api_version) {
|
||||
return "";
|
||||
@@ -100,104 +83,193 @@ export default function AzureModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(AZURE_PROVIDER_NAME);
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
const [addedModels, setAddedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
const initialValues: AzureModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.AZURE,
|
||||
existingLlmProvider
|
||||
),
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
} as AzureModalValues;
|
||||
if (open === false) return null;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
extra: {
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
const onClose = () => {
|
||||
setAddedModels([]);
|
||||
onOpenChange?.(false);
|
||||
};
|
||||
|
||||
const baseModelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
// Merge base models with any user-added models (dedup by name)
|
||||
const existingNames = new Set(baseModelConfigurations.map((m) => m.name));
|
||||
const modelConfigurations = [
|
||||
...baseModelConfigurations,
|
||||
...addedModels.filter((m) => !existingNames.has(m.name)),
|
||||
];
|
||||
|
||||
const initialValues: AzureModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: AZURE_PROVIDER_NAME,
|
||||
provider: AZURE_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
target_uri: "",
|
||||
default_model_name: "",
|
||||
} as AzureModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
},
|
||||
});
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
target_uri: buildTargetUri(existingLlmProvider),
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
target_uri: Yup.string()
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with api-version query parameter and either a deployment name in the path or /openai/responses",
|
||||
(value) => (value ? isValidAzureTargetUri(value) : false)
|
||||
),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.AZURE}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const processedValues = processValues(values);
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.AZURE,
|
||||
values: processedValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: AZURE_PROVIDER_NAME,
|
||||
payload: {
|
||||
...processedValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: AZURE_PROVIDER_NAME,
|
||||
values: processedValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="target_uri"
|
||||
title="Target URI"
|
||||
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={AZURE_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<InputTypeInField
|
||||
name="target_uri"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="target_uri"
|
||||
title="Target URI"
|
||||
subDescription="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="target_uri"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<APIKeyField providerName="Azure" />
|
||||
<APIKeyField providerName="Azure" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onAddModel={(modelName) => {
|
||||
const newModel: ModelConfiguration = {
|
||||
name: modelName,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
};
|
||||
setAddedModels((prev) => [...prev, newModel]);
|
||||
const currentSelected =
|
||||
formikProps.values.selected_model_names ?? [];
|
||||
formikProps.setFieldValue("selected_model_names", [
|
||||
...currentSelected,
|
||||
modelName,
|
||||
]);
|
||||
if (!formikProps.values.default_model_name) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<AzureModelSelection />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import InputSelectField from "@/refresh-components/form/InputSelectField";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
@@ -10,22 +10,30 @@ import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
ModelSelectionField,
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
|
||||
import { Card } from "@opal/components";
|
||||
@@ -33,9 +41,9 @@ import { Section } from "@/layouts/general-layouts";
|
||||
import { SvgAlertCircle } from "@opal/icons";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
|
||||
const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const AWS_REGION_OPTIONS = [
|
||||
{ name: "us-east-1", value: "us-east-1" },
|
||||
{ name: "us-east-2", value: "us-east-2" },
|
||||
@@ -71,15 +79,26 @@ interface BedrockModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface BedrockModalInternalsProps {
|
||||
formikProps: FormikProps<BedrockModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BedrockModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BedrockModalInternalsProps) {
|
||||
const formikProps = useFormikContext<BedrockModalValues>();
|
||||
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
|
||||
|
||||
useEffect(() => {
|
||||
@@ -96,6 +115,11 @@ function BedrockModalInternals({
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [authMethod]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isAuthComplete =
|
||||
authMethod === AUTH_METHOD_IAM ||
|
||||
(authMethod === AUTH_METHOD_ACCESS_KEY &&
|
||||
@@ -115,12 +139,12 @@ function BedrockModalInternals({
|
||||
formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
|
||||
aws_bearer_token_bedrock:
|
||||
formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
|
||||
provider_name: LLMProviderName.BEDROCK,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
setFetchedModels(models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -135,8 +159,16 @@ function BedrockModalInternals({
|
||||
});
|
||||
|
||||
return (
|
||||
<>
|
||||
<InputLayouts.FieldPadder>
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={BEDROCK_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<Section gap={1}>
|
||||
<InputLayouts.Vertical
|
||||
name={FIELD_AWS_REGION_NAME}
|
||||
@@ -190,7 +222,7 @@ function BedrockModalInternals({
|
||||
</InputSelect>
|
||||
</InputLayouts.Vertical>
|
||||
</Section>
|
||||
</InputLayouts.FieldPadder>
|
||||
</FieldWrapper>
|
||||
|
||||
{authMethod === AUTH_METHOD_ACCESS_KEY && (
|
||||
<Card background="light" border="none" padding="sm">
|
||||
@@ -218,7 +250,7 @@ function BedrockModalInternals({
|
||||
)}
|
||||
|
||||
{authMethod === AUTH_METHOD_IAM && (
|
||||
<InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<Card background="none" border="solid" padding="sm">
|
||||
<Content
|
||||
icon={SvgAlertCircle}
|
||||
@@ -227,7 +259,7 @@ function BedrockModalInternals({
|
||||
sizePreset="main-ui"
|
||||
/>
|
||||
</Card>
|
||||
</InputLayouts.FieldPadder>
|
||||
</FieldWrapper>
|
||||
)}
|
||||
|
||||
{authMethod === AUTH_METHOD_LONG_TERM_API_KEY && (
|
||||
@@ -248,24 +280,32 @@ function BedrockModalInternals({
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. us.anthropic.claude-sonnet-4-5-v1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -273,53 +313,88 @@ export default function BedrockModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: BedrockModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.BEDROCK,
|
||||
existingLlmProvider
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ?? "",
|
||||
BEDROCK_AUTH_METHOD:
|
||||
(existingLlmProvider?.custom_config?.BEDROCK_AUTH_METHOD as string) ??
|
||||
"access_key",
|
||||
AWS_ACCESS_KEY_ID:
|
||||
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ?? "",
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
(existingLlmProvider?.custom_config?.AWS_SECRET_ACCESS_KEY as string) ??
|
||||
"",
|
||||
AWS_BEARER_TOKEN_BEDROCK:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
|
||||
},
|
||||
} as BedrockModalValues;
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
extra: {
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
},
|
||||
});
|
||||
const initialValues: BedrockModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BEDROCK_PROVIDER_NAME,
|
||||
provider: BEDROCK_PROVIDER_NAME,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
AWS_REGION_NAME: "",
|
||||
BEDROCK_AUTH_METHOD: "access_key",
|
||||
AWS_ACCESS_KEY_ID: "",
|
||||
AWS_SECRET_ACCESS_KEY: "",
|
||||
AWS_BEARER_TOKEN_BEDROCK: "",
|
||||
},
|
||||
} as BedrockModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
custom_config: {
|
||||
AWS_REGION_NAME:
|
||||
(existingLlmProvider?.custom_config?.AWS_REGION_NAME as string) ??
|
||||
"",
|
||||
BEDROCK_AUTH_METHOD:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.BEDROCK_AUTH_METHOD as string) ?? "access_key",
|
||||
AWS_ACCESS_KEY_ID:
|
||||
(existingLlmProvider?.custom_config?.AWS_ACCESS_KEY_ID as string) ??
|
||||
"",
|
||||
AWS_SECRET_ACCESS_KEY:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_SECRET_ACCESS_KEY as string) ?? "",
|
||||
AWS_BEARER_TOKEN_BEDROCK:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.AWS_BEARER_TOKEN_BEDROCK as string) ?? "",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
custom_config: Yup.object({
|
||||
AWS_REGION_NAME: Yup.string().required("AWS Region is required"),
|
||||
}),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.BEDROCK}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
@@ -332,37 +407,51 @@ export default function BedrockModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.BEDROCK,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BEDROCK_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BEDROCK_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<BedrockModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
{(formikProps) => (
|
||||
<BedrockModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,33 +1,45 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchBifrostModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIBaseField,
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const BIFROST_PROVIDER_NAME = LLMProviderName.BIFROST;
|
||||
const DEFAULT_API_BASE = "";
|
||||
|
||||
interface BifrostModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
@@ -35,15 +47,30 @@ interface BifrostModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface BifrostModalInternalsProps {
|
||||
formikProps: FormikProps<BifrostModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function BifrostModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: BifrostModalInternalsProps) {
|
||||
const formikProps = useFormikContext<BifrostModalValues>();
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
@@ -51,12 +78,12 @@ function BifrostModalInternals({
|
||||
const { models, error } = await fetchBifrostModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: LLMProviderName.BIFROST,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
setFetchedModels(models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -73,39 +100,69 @@ function BifrostModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.BIFROST}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your Bifrost gateway endpoint URL (including API version)."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-bifrost-gateway.com/v1"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
/>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
title="API Key"
|
||||
suffix="optional"
|
||||
subDescription={markdown(
|
||||
"Paste your API key from [Bifrost](https://docs.getbifrost.ai/overview) to access your models."
|
||||
)}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. anthropic/claude-sonnet-4-6" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -113,63 +170,109 @@ export default function BifrostModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
BIFROST_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: BifrostModalValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.BIFROST,
|
||||
existingLlmProvider
|
||||
) as BifrostModalValues;
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
const initialValues: BifrostModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: BIFROST_PROVIDER_NAME,
|
||||
provider: BIFROST_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as BifrostModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.BIFROST}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.BIFROST,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: BIFROST_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<BifrostModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
{(formikProps) => (
|
||||
<BifrostModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -70,9 +70,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
}
|
||||
) {
|
||||
const nameInput = screen.getByPlaceholderText("Display Name");
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
"Provider Name as shown on LiteLLM"
|
||||
);
|
||||
const providerInput = screen.getByPlaceholderText("Provider Name");
|
||||
|
||||
await user.type(nameInput, options.name);
|
||||
await user.type(providerInput, options.provider);
|
||||
@@ -101,7 +99,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal onOpenChange={() => {}} />);
|
||||
render(<CustomModal open={true} onOpenChange={() => {}} />);
|
||||
|
||||
await fillBasicFields(user, {
|
||||
name: "My Custom Provider",
|
||||
@@ -168,7 +166,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ detail: "Invalid API key" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal onOpenChange={() => {}} />);
|
||||
render(<CustomModal open={true} onOpenChange={() => {}} />);
|
||||
|
||||
await fillBasicFields(user, {
|
||||
name: "Bad Provider",
|
||||
@@ -246,6 +244,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
render(
|
||||
<CustomModal
|
||||
existingLlmProvider={existingProvider}
|
||||
open={true}
|
||||
onOpenChange={() => {}}
|
||||
/>
|
||||
);
|
||||
@@ -340,6 +339,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
render(
|
||||
<CustomModal
|
||||
existingLlmProvider={existingProvider}
|
||||
open={true}
|
||||
onOpenChange={() => {}}
|
||||
/>
|
||||
);
|
||||
@@ -406,7 +406,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal shouldMarkAsDefault={true} onOpenChange={() => {}} />);
|
||||
render(
|
||||
<CustomModal
|
||||
shouldMarkAsDefault={true}
|
||||
open={true}
|
||||
onOpenChange={() => {}}
|
||||
/>
|
||||
);
|
||||
|
||||
await fillBasicFields(user, {
|
||||
name: "New Default Provider",
|
||||
@@ -451,7 +457,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ detail: "Database error" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal onOpenChange={() => {}} />);
|
||||
render(<CustomModal open={true} onOpenChange={() => {}} />);
|
||||
|
||||
await fillBasicFields(user, {
|
||||
name: "Test Provider",
|
||||
@@ -486,15 +492,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ id: 1, name: "Provider with Custom Config" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomModal onOpenChange={() => {}} />);
|
||||
render(<CustomModal open={true} onOpenChange={() => {}} />);
|
||||
|
||||
// Fill basic fields
|
||||
const nameInput = screen.getByPlaceholderText("Display Name");
|
||||
await user.type(nameInput, "Cloudflare Provider");
|
||||
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
"Provider Name as shown on LiteLLM"
|
||||
);
|
||||
const providerInput = screen.getByPlaceholderText("Provider Name");
|
||||
await user.type(providerInput, "cloudflare");
|
||||
|
||||
// Click "Add Line" button for custom config (aria-label from KeyValueInput)
|
||||
@@ -504,8 +508,8 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
await user.click(addLineButton);
|
||||
|
||||
// Fill in custom config key-value pair
|
||||
const keyInputs = screen.getAllByRole("textbox", { name: /Key \d+/ });
|
||||
const valueInputs = screen.getAllByRole("textbox", { name: /Value \d+/ });
|
||||
const keyInputs = screen.getAllByPlaceholderText("Key");
|
||||
const valueInputs = screen.getAllByPlaceholderText("Value");
|
||||
|
||||
await user.type(keyInputs[0]!, "CLOUDFLARE_ACCOUNT_ID");
|
||||
await user.type(valueInputs[0]!, "my-account-id-123");
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import { LLMProviderFormProps, ModelConfiguration } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useInitialValues } from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
buildDefaultInitialValues,
|
||||
buildOnboardingInitialValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
FieldSeparator,
|
||||
ModelsAccessField,
|
||||
LLMConfigurationModalWrapper,
|
||||
FieldWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
@@ -28,9 +30,7 @@ import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button, Card, EmptyMessageCard } from "@opal/components";
|
||||
import { SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
@@ -107,10 +107,13 @@ function ModelConfigurationItem({
|
||||
);
|
||||
}
|
||||
|
||||
function ModelConfigurationList() {
|
||||
const formikProps = useFormikContext<{
|
||||
interface ModelConfigurationListProps {
|
||||
formikProps: FormikProps<{
|
||||
model_configurations: CustomModelConfiguration[];
|
||||
}>();
|
||||
}>;
|
||||
}
|
||||
|
||||
function ModelConfigurationList({ formikProps }: ModelConfigurationListProps) {
|
||||
const models = formikProps.values.model_configurations;
|
||||
|
||||
function handleChange(index: number, next: CustomModelConfiguration) {
|
||||
@@ -176,68 +179,55 @@ function ModelConfigurationList() {
|
||||
);
|
||||
}
|
||||
|
||||
function CustomConfigKeyValue() {
|
||||
const formikProps = useFormikContext<{ custom_config_list: KeyValue[] }>();
|
||||
return (
|
||||
<KeyValueInput
|
||||
items={formikProps.values.custom_config_list}
|
||||
onChange={(items) =>
|
||||
formikProps.setFieldValue("custom_config_list", items)
|
||||
}
|
||||
addButtonLabel="Add Line"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Custom Config Processing ─────────────────────────────────────────────────
|
||||
|
||||
function keyValueListToDict(items: KeyValue[]): Record<string, string> {
|
||||
const result: Record<string, string> = {};
|
||||
for (const { key, value } of items) {
|
||||
if (key.trim() !== "") {
|
||||
result[key] = value;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
function customConfigProcessing(items: KeyValue[]) {
|
||||
const customConfig: { [key: string]: string } = {};
|
||||
items.forEach(({ key, value }) => {
|
||||
customConfig[key] = value;
|
||||
});
|
||||
return customConfig;
|
||||
}
|
||||
|
||||
export default function CustomModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.CUSTOM,
|
||||
existingLlmProvider
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
undefined,
|
||||
defaultModelName
|
||||
),
|
||||
...(isOnboarding ? buildOnboardingInitialValues() : {}),
|
||||
provider: existingLlmProvider?.provider ?? "",
|
||||
api_version: existingLlmProvider?.api_version ?? "",
|
||||
model_configurations: existingLlmProvider?.model_configurations.map(
|
||||
(mc) => ({
|
||||
name: mc.name,
|
||||
display_name: mc.display_name ?? "",
|
||||
is_visible: mc.is_visible,
|
||||
max_input_tokens: mc.max_input_tokens ?? null,
|
||||
supports_image_input: mc.supports_image_input,
|
||||
supports_reasoning: mc.supports_reasoning,
|
||||
})
|
||||
) ?? [
|
||||
{
|
||||
name: "",
|
||||
display_name: "",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
@@ -269,13 +259,11 @@ export default function CustomModal({
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.CUSTOM}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
setSubmitting(true);
|
||||
|
||||
const modelConfigurations = values.model_configurations
|
||||
@@ -295,127 +283,127 @@ export default function CustomModal({
|
||||
return;
|
||||
}
|
||||
|
||||
// Always send custom_config as a dict (even empty) so the backend
|
||||
// preserves it as non-null — this is the signal that the provider was
|
||||
// created via CustomModal.
|
||||
const customConfig = keyValueListToDict(values.custom_config_list);
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
await submitOnboardingProvider({
|
||||
providerName: values.provider,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigurations,
|
||||
custom_config: customConfigProcessing(values.custom_config_list),
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: true,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
const selectedModelNames = modelConfigurations.map(
|
||||
(config) => config.name
|
||||
);
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: (values as Record<string, unknown>).provider as string,
|
||||
values: {
|
||||
...values,
|
||||
model_configurations: modelConfigurations,
|
||||
custom_config: customConfig,
|
||||
},
|
||||
initialValues: {
|
||||
...initialValues,
|
||||
custom_config: keyValueListToDict(initialValues.custom_config_list),
|
||||
},
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
isCustomProvider: true,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
await submitLLMProvider({
|
||||
providerName: values.provider,
|
||||
values: {
|
||||
...values,
|
||||
selected_model_names: selectedModelNames,
|
||||
custom_config: customConfigProcessing(values.custom_config_list),
|
||||
},
|
||||
initialValues: {
|
||||
...initialValues,
|
||||
custom_config: customConfigProcessing(
|
||||
initialValues.custom_config_list
|
||||
),
|
||||
},
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{!isOnboarding && (
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="provider"
|
||||
title="Provider Name"
|
||||
subDescription={markdown(
|
||||
"Should be one of the providers listed at [LiteLLM](https://docs.litellm.ai/docs/providers)."
|
||||
)}
|
||||
>
|
||||
<InputTypeInField
|
||||
name="provider"
|
||||
placeholder="Provider Name as shown on LiteLLM"
|
||||
variant={existingLlmProvider ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
)}
|
||||
|
||||
<APIBaseField optional />
|
||||
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="api_version"
|
||||
title="API Version"
|
||||
suffix="optional"
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint="custom"
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<InputTypeInField name="api_version" />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
{!isOnboarding && (
|
||||
<Section gap={0}>
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription="Paste your API key if your model provider requires authentication."
|
||||
/>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="provider"
|
||||
title="Provider Name"
|
||||
subDescription="Should be one of the providers listed at https://docs.litellm.ai/docs/providers."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="provider"
|
||||
placeholder="Provider Name"
|
||||
variant={existingLlmProvider ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
</Section>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldPadder>
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Additional Configs"
|
||||
description={markdown(
|
||||
"Add extra properties as needed by the model provider. These are passed to LiteLLM's `completion()` call as [environment variables](https://docs.litellm.ai/docs/set_keys#environment-variables). See [documentation](https://docs.onyx.app/admins/ai_models/custom_inference_provider) for more instructions."
|
||||
)}
|
||||
widthVariant="full"
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
/>
|
||||
<FieldSeparator />
|
||||
|
||||
<CustomConfigKeyValue />
|
||||
</Section>
|
||||
</InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Provider Configs"
|
||||
description="Add properties as needed by the model provider. This is passed to LiteLLM completion() call as arguments in the environment variable. See LiteLLM documentation for more instructions."
|
||||
widthVariant="full"
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
<KeyValueInput
|
||||
items={formikProps.values.custom_config_list}
|
||||
onChange={(items) =>
|
||||
formikProps.setFieldValue("custom_config_list", items)
|
||||
}
|
||||
addButtonLabel="Add Line"
|
||||
/>
|
||||
</Section>
|
||||
</FieldWrapper>
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
<Section gap={0.5}>
|
||||
<FieldWrapper>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</FieldWrapper>
|
||||
|
||||
<Card padding="sm">
|
||||
<ModelConfigurationList formikProps={formikProps as any} />
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<Section gap={0.5}>
|
||||
<InputLayouts.FieldPadder>
|
||||
<Content
|
||||
title="Models"
|
||||
description="List LLM models you wish to use and their configurations for this provider. See full list of models at LiteLLM."
|
||||
variant="section"
|
||||
sizePreset="main-content"
|
||||
widthVariant="full"
|
||||
/>
|
||||
</InputLayouts.FieldPadder>
|
||||
|
||||
<Card padding="sm">
|
||||
<ModelConfigurationList />
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
315
web/src/sections/modals/llmConfig/LMStudioForm.tsx
Normal file
315
web/src/sections/modals/llmConfig/LMStudioForm.tsx
Normal file
@@ -0,0 +1,315 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
interface LMStudioFormValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface LMStudioFormInternalsProps {
|
||||
formikProps: FormikProps<LMStudioFormValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LMStudioFormInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: LMStudioFormInternalsProps) {
|
||||
const initialApiKey =
|
||||
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ?? "";
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
|
||||
fetchModels(
|
||||
LLMProviderName.LM_STUDIO,
|
||||
{
|
||||
api_base: apiBase,
|
||||
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
|
||||
api_key_changed: apiKey !== initialApiKey,
|
||||
name: existingLlmProvider?.name,
|
||||
},
|
||||
signal
|
||||
).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setFetchedModels(data.models);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, initialApiKey, setFetchedModels]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
const apiBase = formikProps.values.api_base;
|
||||
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
useEffect(() => {
|
||||
if (apiBase) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(apiBase, apiKey, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [apiBase, apiKey, debouncedFetchModels, setFetchedModels]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
return (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.LM_STUDIO}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="The base URL for your LM Studio server."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="Your LM Studio API base URL"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
title="API Key"
|
||||
subDescription="Optional API key if your LM Studio server requires authentication."
|
||||
suffix="optional"
|
||||
>
|
||||
<PasswordInputTypeInField
|
||||
name="custom_config.LM_STUDIO_API_KEY"
|
||||
placeholder="API Key"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. llama3.1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LMStudioForm({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
LLMProviderName.LM_STUDIO
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: LMStudioFormValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: LLMProviderName.LM_STUDIO,
|
||||
provider: LLMProviderName.LM_STUDIO,
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY: "",
|
||||
},
|
||||
} as LMStudioFormValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY:
|
||||
(existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY as string) ??
|
||||
"",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<LMStudioFormInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
@@ -1,213 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues as BaseLLMModalValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:1234";
|
||||
|
||||
interface LMStudioModalValues extends BaseLLMModalValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface LMStudioModalInternalsProps {
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LMStudioModalInternals({
|
||||
existingLlmProvider,
|
||||
isOnboarding,
|
||||
}: LMStudioModalInternalsProps) {
|
||||
const formikProps = useFormikContext<LMStudioModalValues>();
|
||||
const initialApiKey = existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, apiKey: string | undefined, signal: AbortSignal) => {
|
||||
fetchModels(
|
||||
LLMProviderName.LM_STUDIO,
|
||||
{
|
||||
api_base: apiBase,
|
||||
custom_config: apiKey ? { LM_STUDIO_API_KEY: apiKey } : {},
|
||||
api_key_changed: apiKey !== initialApiKey,
|
||||
name: existingLlmProvider?.name,
|
||||
},
|
||||
signal
|
||||
).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
formikProps.setFieldValue("model_configurations", []);
|
||||
return;
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", data.models);
|
||||
});
|
||||
},
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
[existingLlmProvider?.name, initialApiKey]
|
||||
);
|
||||
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
const apiBase = formikProps.values.api_base;
|
||||
const apiKey = formikProps.values.custom_config?.LM_STUDIO_API_KEY;
|
||||
|
||||
useEffect(() => {
|
||||
if (apiBase) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(apiBase, apiKey, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
formikProps.setFieldValue("model_configurations", []);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [apiBase, apiKey, debouncedFetchModels]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL for your LM Studio server."
|
||||
placeholder="Your LM Studio API base URL"
|
||||
/>
|
||||
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription="Optional API key if your LM Studio server requires authentication."
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={false} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LMStudioModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: LMStudioModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.LM_STUDIO,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
LM_STUDIO_API_KEY: existingLlmProvider?.custom_config?.LM_STUDIO_API_KEY,
|
||||
},
|
||||
} as LMStudioModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.LM_STUDIO}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.LM_STUDIO,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<LMStudioModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -1,32 +1,41 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchLiteLLMProxyModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const DEFAULT_API_BASE = "http://localhost:4000";
|
||||
|
||||
@@ -36,15 +45,30 @@ interface LiteLLMProxyModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface LiteLLMProxyModalInternalsProps {
|
||||
formikProps: FormikProps<LiteLLMProxyModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function LiteLLMProxyModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: LiteLLMProxyModalInternalsProps) {
|
||||
const formikProps = useFormikContext<LiteLLMProxyModalValues>();
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isFetchDisabled =
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
@@ -53,12 +77,12 @@ function LiteLLMProxyModalInternals({
|
||||
const { models, error } = await fetchLiteLLMProxyModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: LLMProviderName.LITELLM_PROXY,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
setFetchedModels(models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -74,34 +98,58 @@ function LiteLLMProxyModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL for your LiteLLM Proxy server."
|
||||
placeholder="https://your-litellm-proxy.com"
|
||||
/>
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={LLMProviderName.LITELLM_PROXY}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="The base URL for your LiteLLM Proxy server."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="https://your-litellm-proxy.com"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<APIKeyField providerName="LiteLLM Proxy" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -109,67 +157,111 @@ export default function LiteLLMProxyModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
LLMProviderName.LITELLM_PROXY
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: LiteLLMProxyModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.LITELLM_PROXY,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
} as LiteLLMProxyModalValues;
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
apiBase: true,
|
||||
});
|
||||
const initialValues: LiteLLMProxyModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: LLMProviderName.LITELLM_PROXY,
|
||||
provider: LLMProviderName.LITELLM_PROXY,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as LiteLLMProxyModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.LITELLM_PROXY}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: LLMProviderName.LITELLM_PROXY,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<LiteLLMProxyModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
{(formikProps) => (
|
||||
<LiteLLMProxyModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,44 +1,47 @@
|
||||
"use client";
|
||||
|
||||
import * as Yup from "yup";
|
||||
import { Dispatch, SetStateAction, useMemo, useState } from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
ModelSelectionField,
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
import debounce from "lodash/debounce";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { Card } from "@opal/components";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
|
||||
const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
const CLOUD_API_BASE = "https://ollama.com";
|
||||
|
||||
enum Tab {
|
||||
TAB_SELF_HOSTED = "self-hosted",
|
||||
TAB_CLOUD = "cloud",
|
||||
}
|
||||
const TAB_SELF_HOSTED = "self-hosted";
|
||||
const TAB_CLOUD = "cloud";
|
||||
|
||||
interface OllamaModalValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
@@ -48,67 +51,104 @@ interface OllamaModalValues extends BaseLLMFormValues {
|
||||
}
|
||||
|
||||
interface OllamaModalInternalsProps {
|
||||
formikProps: FormikProps<OllamaModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
tab: Tab;
|
||||
setTab: Dispatch<SetStateAction<Tab>>;
|
||||
}
|
||||
|
||||
function OllamaModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
tab,
|
||||
setTab,
|
||||
}: OllamaModalInternalsProps) {
|
||||
const formikProps = useFormikContext<OllamaModalValues>();
|
||||
const isInitialMount = useRef(true);
|
||||
|
||||
const isFetchDisabled = useMemo(
|
||||
() =>
|
||||
tab === Tab.TAB_SELF_HOSTED
|
||||
? !formikProps.values.api_base
|
||||
: !formikProps.values.custom_config.OLLAMA_API_KEY,
|
||||
[tab, formikProps]
|
||||
const doFetchModels = useCallback(
|
||||
(apiBase: string, signal: AbortSignal) => {
|
||||
fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
signal,
|
||||
}).then((data) => {
|
||||
if (signal.aborted) return;
|
||||
if (data.error) {
|
||||
toast.error(data.error);
|
||||
setFetchedModels([]);
|
||||
return;
|
||||
}
|
||||
setFetchedModels(data.models);
|
||||
});
|
||||
},
|
||||
[existingLlmProvider?.name, setFetchedModels]
|
||||
);
|
||||
|
||||
const handleFetchModels = async (signal?: AbortSignal) => {
|
||||
// Only Ollama cloud accepts API key
|
||||
const apiBase = formikProps.values.custom_config?.OLLAMA_API_KEY
|
||||
? CLOUD_API_BASE
|
||||
: formikProps.values.api_base;
|
||||
const { models, error } = await fetchOllamaModels({
|
||||
api_base: apiBase,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
signal,
|
||||
});
|
||||
if (signal?.aborted) return;
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
const debouncedFetchModels = useMemo(
|
||||
() => debounce(doFetchModels, 500),
|
||||
[doFetchModels]
|
||||
);
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useOnMount(() => {
|
||||
if (existingLlmProvider) {
|
||||
handleFetchModels().catch((err) => {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
// Skip the initial fetch for new providers — api_base starts with a default
|
||||
// value, which would otherwise trigger a fetch before the user has done
|
||||
// anything. Existing providers should still auto-fetch on mount.
|
||||
useEffect(() => {
|
||||
if (isInitialMount.current) {
|
||||
isInitialMount.current = false;
|
||||
if (!existingLlmProvider) return;
|
||||
}
|
||||
});
|
||||
|
||||
if (formikProps.values.api_base) {
|
||||
const controller = new AbortController();
|
||||
debouncedFetchModels(formikProps.values.api_base, controller.signal);
|
||||
return () => {
|
||||
debouncedFetchModels.cancel();
|
||||
controller.abort();
|
||||
};
|
||||
} else {
|
||||
setFetchedModels([]);
|
||||
}
|
||||
}, [
|
||||
formikProps.values.api_base,
|
||||
debouncedFetchModels,
|
||||
setFetchedModels,
|
||||
existingLlmProvider,
|
||||
]);
|
||||
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || [];
|
||||
|
||||
const hasApiKey = !!formikProps.values.custom_config?.OLLAMA_API_KEY;
|
||||
const defaultTab =
|
||||
existingLlmProvider && hasApiKey ? TAB_CLOUD : TAB_SELF_HOSTED;
|
||||
|
||||
return (
|
||||
<>
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OLLAMA_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<Card background="light" border="none" padding="sm">
|
||||
<Tabs value={tab} onValueChange={(value) => setTab(value as Tab)}>
|
||||
<Tabs defaultValue={defaultTab}>
|
||||
<Tabs.List>
|
||||
<Tabs.Trigger value={Tab.TAB_SELF_HOSTED}>
|
||||
<Tabs.Trigger value={TAB_SELF_HOSTED}>
|
||||
Self-hosted Ollama
|
||||
</Tabs.Trigger>
|
||||
<Tabs.Trigger value={Tab.TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
|
||||
<Tabs.Trigger value={TAB_CLOUD}>Ollama Cloud</Tabs.Trigger>
|
||||
</Tabs.List>
|
||||
<Tabs.Content value={Tab.TAB_SELF_HOSTED} padding={0}>
|
||||
<Tabs.Content value={TAB_SELF_HOSTED}>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
@@ -121,7 +161,7 @@ function OllamaModalInternals({
|
||||
</InputLayouts.Vertical>
|
||||
</Tabs.Content>
|
||||
|
||||
<Tabs.Content value={Tab.TAB_CLOUD}>
|
||||
<Tabs.Content value={TAB_CLOUD}>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.OLLAMA_API_KEY"
|
||||
title="API Key"
|
||||
@@ -138,24 +178,31 @@ function OllamaModalInternals({
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. llama3.1" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -163,102 +210,125 @@ export default function OllamaModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const apiKey = existingLlmProvider?.custom_config?.OLLAMA_API_KEY;
|
||||
const defaultTab =
|
||||
existingLlmProvider && !!apiKey ? Tab.TAB_CLOUD : Tab.TAB_SELF_HOSTED;
|
||||
const [tab, setTab] = useState<Tab>(defaultTab);
|
||||
const { wellKnownLLMProvider } =
|
||||
useWellKnownLLMProvider(OLLAMA_PROVIDER_NAME);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: OllamaModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OLLAMA_CHAT,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY: apiKey,
|
||||
},
|
||||
} as OllamaModalValues;
|
||||
|
||||
const validationSchema = useMemo(
|
||||
() =>
|
||||
buildValidationSchema(isOnboarding, {
|
||||
apiBase: tab === Tab.TAB_SELF_HOSTED,
|
||||
extra:
|
||||
tab === Tab.TAB_CLOUD
|
||||
? {
|
||||
custom_config: Yup.object({
|
||||
OLLAMA_API_KEY: Yup.string().required("API Key is required"),
|
||||
}),
|
||||
}
|
||||
: undefined,
|
||||
}),
|
||||
[tab, isOnboarding]
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const initialValues: OllamaModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OLLAMA_PROVIDER_NAME,
|
||||
provider: OLLAMA_PROVIDER_NAME,
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY: "",
|
||||
},
|
||||
} as OllamaModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY:
|
||||
(existingLlmProvider?.custom_config?.OLLAMA_API_KEY as string) ??
|
||||
"",
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OLLAMA_CHAT}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(([, v]) => v !== "")
|
||||
);
|
||||
|
||||
const submitValues = {
|
||||
...values,
|
||||
api_base: filteredCustomConfig.OLLAMA_API_KEY
|
||||
? CLOUD_API_BASE
|
||||
: values.api_base,
|
||||
custom_config:
|
||||
Object.keys(filteredCustomConfig).length > 0
|
||||
? filteredCustomConfig
|
||||
: undefined,
|
||||
};
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OLLAMA_CHAT,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OLLAMA_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<OllamaModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
tab={tab}
|
||||
setTab={setTab}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
{(formikProps) => (
|
||||
<OllamaModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { markdown } from "@opal/utils";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchOpenAICompatibleModels } from "@/app/admin/configuration/llm/utils";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
APIBaseField,
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
interface OpenAICompatibleModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface OpenAICompatibleModalInternalsProps {
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function OpenAICompatibleModalInternals({
|
||||
existingLlmProvider,
|
||||
isOnboarding,
|
||||
}: OpenAICompatibleModalInternalsProps) {
|
||||
const formikProps = useFormikContext<OpenAICompatibleModalValues>();
|
||||
|
||||
const isFetchDisabled = !formikProps.values.api_base;
|
||||
|
||||
const handleFetchModels = async () => {
|
||||
const { models, error } = await fetchOpenAICompatibleModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key || undefined,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
useEffect(() => {
|
||||
if (existingLlmProvider && !isFetchDisabled) {
|
||||
handleFetchModels().catch((err) => {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
});
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="The base URL of your OpenAI-compatible server."
|
||||
placeholder="http://localhost:8000/v1"
|
||||
/>
|
||||
|
||||
<APIKeyField
|
||||
optional
|
||||
subDescription={markdown(
|
||||
"Provide an API key if your server requires authentication."
|
||||
)}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function OpenAICompatibleModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENAI_COMPATIBLE,
|
||||
existingLlmProvider
|
||||
) as OpenAICompatibleModalValues;
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiBase: true,
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENAI_COMPATIBLE}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENAI_COMPATIBLE,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<OpenAICompatibleModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
);
|
||||
}
|
||||
@@ -1,99 +1,175 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import { Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
ModelSelectionField,
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
FieldSeparator,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const OPENAI_PROVIDER_NAME = "openai";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
|
||||
|
||||
export default function OpenAIModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } =
|
||||
useWellKnownLLMProvider(OPENAI_PROVIDER_NAME);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues = useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENAI,
|
||||
existingLlmProvider
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
});
|
||||
const initialValues = isOnboarding
|
||||
? {
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OPENAI_PROVIDER_NAME,
|
||||
provider: OPENAI_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
default_model_name: DEFAULT_DEFAULT_MODEL_NAME,
|
||||
}
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENAI}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENAI,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OPENAI_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === DEFAULT_DEFAULT_MODEL_NAME,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OPENAI_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<APIKeyField providerName="OpenAI" />
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OPENAI_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<APIKeyField providerName="OpenAI" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gpt-5.2" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</LLMConfigurationModalWrapper>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,50 +1,73 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { useFormikContext } from "formik";
|
||||
import { Formik, FormikProps } from "formik";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import { fetchOpenRouterModels } from "@/app/admin/configuration/llm/utils";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
APIKeyField,
|
||||
APIBaseField,
|
||||
ModelSelectionField,
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
ModelsAccessField,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
|
||||
const OPENROUTER_PROVIDER_NAME = "openrouter";
|
||||
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
|
||||
|
||||
interface OpenRouterModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
|
||||
interface OpenRouterModalInternalsProps {
|
||||
formikProps: FormikProps<OpenRouterModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
isTesting: boolean;
|
||||
onClose: () => void;
|
||||
isOnboarding: boolean;
|
||||
}
|
||||
|
||||
function OpenRouterModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
setFetchedModels,
|
||||
modelConfigurations,
|
||||
isTesting,
|
||||
onClose,
|
||||
isOnboarding,
|
||||
}: OpenRouterModalInternalsProps) {
|
||||
const formikProps = useFormikContext<OpenRouterModalValues>();
|
||||
const currentModels =
|
||||
fetchedModels.length > 0
|
||||
? fetchedModels
|
||||
: existingLlmProvider?.model_configurations || modelConfigurations;
|
||||
|
||||
const isFetchDisabled =
|
||||
!formikProps.values.api_base || !formikProps.values.api_key;
|
||||
@@ -53,12 +76,12 @@ function OpenRouterModalInternals({
|
||||
const { models, error } = await fetchOpenRouterModels({
|
||||
api_base: formikProps.values.api_base,
|
||||
api_key: formikProps.values.api_key,
|
||||
provider_name: LLMProviderName.OPENROUTER,
|
||||
provider_name: existingLlmProvider?.name,
|
||||
});
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
formikProps.setFieldValue("model_configurations", models);
|
||||
setFetchedModels(models);
|
||||
};
|
||||
|
||||
// Auto-fetch models on initial load when editing an existing provider
|
||||
@@ -74,34 +97,58 @@ function OpenRouterModalInternals({
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<>
|
||||
<APIBaseField
|
||||
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
|
||||
placeholder="Your OpenRouter base URL"
|
||||
/>
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={OPENROUTER_PROVIDER_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription="Paste your OpenRouter-compatible endpoint URL or use OpenRouter API directly."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="api_base"
|
||||
placeholder="Your OpenRouter base URL"
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<APIKeyField providerName="OpenRouter" />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. openai/gpt-4o" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={currentModels}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={null}
|
||||
shouldShowAutoUpdateToggle={false}
|
||||
onRefetch={isFetchDisabled ? undefined : handleFetchModels}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
<FieldSeparator />
|
||||
<ModelsAccessField formikProps={formikProps} />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
</LLMConfigurationModalWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -109,67 +156,111 @@ export default function OpenRouterModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
OPENROUTER_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: OpenRouterModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.OPENROUTER,
|
||||
existingLlmProvider
|
||||
),
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
} as OpenRouterModalValues;
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
apiKey: true,
|
||||
apiBase: true,
|
||||
});
|
||||
const initialValues: OpenRouterModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: OPENROUTER_PROVIDER_NAME,
|
||||
provider: OPENROUTER_PROVIDER_NAME,
|
||||
api_key: "",
|
||||
api_base: DEFAULT_API_BASE,
|
||||
default_model_name: "",
|
||||
} as OpenRouterModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.OPENROUTER}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.OPENROUTER,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
fetchedModels.length > 0 ? fetchedModels : [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: OPENROUTER_PROVIDER_NAME,
|
||||
payload: {
|
||||
...values,
|
||||
model_configurations: modelConfigsToUse,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: OPENROUTER_PROVIDER_NAME,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations:
|
||||
fetchedModels.length > 0 ? fetchedModels : modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<OpenRouterModalInternals
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
</ModalWrapper>
|
||||
{(formikProps) => (
|
||||
<OpenRouterModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
setFetchedModels={setFetchedModels}
|
||||
modelConfigurations={modelConfigurations}
|
||||
isTesting={isTesting}
|
||||
onClose={onClose}
|
||||
isOnboarding={isOnboarding}
|
||||
/>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,27 +1,38 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { Formik } from "formik";
|
||||
import { FileUploadFormField } from "@/components/Field";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { LLMProviderFormProps, LLMProviderName } from "@/interfaces/llm";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
useInitialValues,
|
||||
buildValidationSchema,
|
||||
buildDefaultInitialValues,
|
||||
buildDefaultValidationSchema,
|
||||
buildAvailableModelConfigurations,
|
||||
buildOnboardingInitialValues,
|
||||
BaseLLMFormValues,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
import { submitProvider } from "@/sections/modals/llmConfig/svc";
|
||||
import { LLMProviderConfiguredSource } from "@/lib/analytics";
|
||||
import {
|
||||
ModelSelectionField,
|
||||
submitLLMProvider,
|
||||
submitOnboardingProvider,
|
||||
} from "@/sections/modals/llmConfig/svc";
|
||||
import {
|
||||
ModelsField,
|
||||
DisplayNameField,
|
||||
ModelAccessField,
|
||||
ModalWrapper,
|
||||
FieldSeparator,
|
||||
FieldWrapper,
|
||||
ModelsAccessField,
|
||||
SingleDefaultModelField,
|
||||
LLMConfigurationModalWrapper,
|
||||
} from "@/sections/modals/llmConfig/shared";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
const VERTEXAI_PROVIDER_NAME = "vertex_ai";
|
||||
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
|
||||
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
|
||||
const VERTEXAI_DEFAULT_LOCATION = "global";
|
||||
|
||||
interface VertexAIModalValues extends BaseLLMFormValues {
|
||||
@@ -35,49 +46,89 @@ export default function VertexAIModal({
|
||||
variant = "llm-configuration",
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
onSuccess,
|
||||
defaultModelName,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
llmDescriptor,
|
||||
}: LLMProviderFormProps) {
|
||||
const isOnboarding = variant === "onboarding";
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const { mutate } = useSWRConfig();
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(
|
||||
VERTEXAI_PROVIDER_NAME
|
||||
);
|
||||
|
||||
if (open === false) return null;
|
||||
|
||||
const onClose = () => onOpenChange?.(false);
|
||||
|
||||
const initialValues: VertexAIModalValues = {
|
||||
...useInitialValues(
|
||||
isOnboarding,
|
||||
LLMProviderName.VERTEX_AI,
|
||||
existingLlmProvider
|
||||
),
|
||||
custom_config: {
|
||||
vertex_credentials:
|
||||
(existingLlmProvider?.custom_config?.vertex_credentials as string) ??
|
||||
"",
|
||||
vertex_location:
|
||||
(existingLlmProvider?.custom_config?.vertex_location as string) ??
|
||||
VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
} as VertexAIModalValues;
|
||||
const modelConfigurations = buildAvailableModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? llmDescriptor
|
||||
);
|
||||
|
||||
const validationSchema = buildValidationSchema(isOnboarding, {
|
||||
extra: {
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
const initialValues: VertexAIModalValues = isOnboarding
|
||||
? ({
|
||||
...buildOnboardingInitialValues(),
|
||||
name: VERTEXAI_PROVIDER_NAME,
|
||||
provider: VERTEXAI_PROVIDER_NAME,
|
||||
default_model_name: VERTEXAI_DEFAULT_MODEL,
|
||||
custom_config: {
|
||||
vertex_credentials: "",
|
||||
vertex_location: VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
} as VertexAIModalValues)
|
||||
: {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations,
|
||||
defaultModelName
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
},
|
||||
});
|
||||
default_model_name:
|
||||
(defaultModelName &&
|
||||
modelConfigurations.some((m) => m.name === defaultModelName)
|
||||
? defaultModelName
|
||||
: undefined) ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
custom_config: {
|
||||
vertex_credentials:
|
||||
(existingLlmProvider?.custom_config
|
||||
?.vertex_credentials as string) ?? "",
|
||||
vertex_location:
|
||||
(existingLlmProvider?.custom_config?.vertex_location as string) ??
|
||||
VERTEXAI_DEFAULT_LOCATION,
|
||||
},
|
||||
};
|
||||
|
||||
const validationSchema = isOnboarding
|
||||
? Yup.object().shape({
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
})
|
||||
: buildDefaultValidationSchema().shape({
|
||||
custom_config: Yup.object({
|
||||
vertex_credentials: Yup.string().required(
|
||||
"Credentials file is required"
|
||||
),
|
||||
vertex_location: Yup.string(),
|
||||
}),
|
||||
});
|
||||
|
||||
return (
|
||||
<ModalWrapper
|
||||
providerName={LLMProviderName.VERTEX_AI}
|
||||
llmProvider={existingLlmProvider}
|
||||
onClose={onClose}
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={async (values, { setSubmitting, setStatus }) => {
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
const filteredCustomConfig = Object.fromEntries(
|
||||
Object.entries(values.custom_config || {}).filter(
|
||||
([key, v]) => key === "vertex_credentials" || v !== ""
|
||||
@@ -92,75 +143,101 @@ export default function VertexAIModal({
|
||||
: undefined,
|
||||
};
|
||||
|
||||
await submitProvider({
|
||||
analyticsSource: isOnboarding
|
||||
? LLMProviderConfiguredSource.CHAT_ONBOARDING
|
||||
: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
providerName: LLMProviderName.VERTEX_AI,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess: async () => {
|
||||
if (onSuccess) {
|
||||
await onSuccess();
|
||||
} else {
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
toast.success(
|
||||
existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!"
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
if (isOnboarding && onboardingState && onboardingActions) {
|
||||
const modelConfigsToUse =
|
||||
(wellKnownLLMProvider ?? llmDescriptor)?.known_models ?? [];
|
||||
|
||||
await submitOnboardingProvider({
|
||||
providerName: VERTEXAI_PROVIDER_NAME,
|
||||
payload: {
|
||||
...submitValues,
|
||||
model_configurations: modelConfigsToUse,
|
||||
is_auto_mode:
|
||||
values.default_model_name === VERTEXAI_DEFAULT_MODEL,
|
||||
},
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider: false,
|
||||
onClose,
|
||||
setIsSubmitting: setSubmitting,
|
||||
});
|
||||
} else {
|
||||
await submitLLMProvider({
|
||||
providerName: VERTEXAI_PROVIDER_NAME,
|
||||
values: submitValues,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_location"
|
||||
title="Google Cloud Region Name"
|
||||
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
|
||||
{(formikProps) => (
|
||||
<LLMConfigurationModalWrapper
|
||||
providerEndpoint={VERTEXAI_PROVIDER_NAME}
|
||||
providerName={VERTEXAI_DISPLAY_NAME}
|
||||
existingProviderName={existingLlmProvider?.name}
|
||||
onClose={onClose}
|
||||
isFormValid={formikProps.isValid}
|
||||
isDirty={formikProps.dirty}
|
||||
isTesting={isTesting}
|
||||
isSubmitting={formikProps.isSubmitting}
|
||||
>
|
||||
<InputTypeInField
|
||||
name="custom_config.vertex_location"
|
||||
placeholder={VERTEXAI_DEFAULT_LOCATION}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_location"
|
||||
title="Google Cloud Region Name"
|
||||
subDescription="Region where your Google Vertex AI models are hosted. See full list of regions supported at Google Cloud."
|
||||
>
|
||||
<InputTypeInField
|
||||
name="custom_config.vertex_location"
|
||||
placeholder={VERTEXAI_DEFAULT_LOCATION}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_credentials"
|
||||
title="API Key"
|
||||
subDescription="Attach your API key JSON from Google Cloud to access your models."
|
||||
>
|
||||
<FileUploadFormField
|
||||
name="custom_config.vertex_credentials"
|
||||
label=""
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="custom_config.vertex_credentials"
|
||||
title="API Key"
|
||||
subDescription="Attach your API key JSON from Google Cloud to access your models."
|
||||
>
|
||||
<FileUploadFormField
|
||||
name="custom_config.vertex_credentials"
|
||||
label=""
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
</>
|
||||
<FieldSeparator />
|
||||
|
||||
{!isOnboarding && (
|
||||
<DisplayNameField disabled={!!existingLlmProvider} />
|
||||
)}
|
||||
|
||||
<FieldSeparator />
|
||||
|
||||
{isOnboarding ? (
|
||||
<SingleDefaultModelField placeholder="E.g. gemini-2.5-pro" />
|
||||
) : (
|
||||
<ModelsField
|
||||
modelConfigurations={modelConfigurations}
|
||||
formikProps={formikProps}
|
||||
recommendedDefaultModel={
|
||||
wellKnownLLMProvider?.recommended_default_model ?? null
|
||||
}
|
||||
shouldShowAutoUpdateToggle={true}
|
||||
/>
|
||||
)}
|
||||
|
||||
{!isOnboarding && <ModelsAccessField formikProps={formikProps} />}
|
||||
</LLMConfigurationModalWrapper>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelSelectionField shouldShowAutoUpdateToggle={true} />
|
||||
|
||||
{!isOnboarding && (
|
||||
<>
|
||||
<InputLayouts.FieldSeparator />
|
||||
<ModelAccessField />
|
||||
</>
|
||||
)}
|
||||
</ModalWrapper>
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,68 +7,58 @@ import VertexAIModal from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import OpenRouterModal from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import CustomModal from "@/sections/modals/llmConfig/CustomModal";
|
||||
import BedrockModal from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import LMStudioModal from "@/sections/modals/llmConfig/LMStudioModal";
|
||||
import LMStudioForm from "@/sections/modals/llmConfig/LMStudioForm";
|
||||
import LiteLLMProxyModal from "@/sections/modals/llmConfig/LiteLLMProxyModal";
|
||||
import BifrostModal from "@/sections/modals/llmConfig/BifrostModal";
|
||||
import OpenAICompatibleModal from "@/sections/modals/llmConfig/OpenAICompatibleModal";
|
||||
|
||||
function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
provider.provider === LLMProviderName.OPENAI &&
|
||||
provider.api_key &&
|
||||
!provider.api_base &&
|
||||
Object.keys(provider.custom_config || {}).length === 0
|
||||
);
|
||||
}
|
||||
|
||||
export function getModalForExistingProvider(
|
||||
provider: LLMProviderView,
|
||||
open?: boolean,
|
||||
onOpenChange?: (open: boolean) => void,
|
||||
defaultModelName?: string
|
||||
) {
|
||||
const props = {
|
||||
existingLlmProvider: provider,
|
||||
open,
|
||||
onOpenChange,
|
||||
defaultModelName,
|
||||
};
|
||||
|
||||
const hasCustomConfig = provider.custom_config != null;
|
||||
|
||||
switch (provider.provider) {
|
||||
// These providers don't use custom_config themselves, so a non-null
|
||||
// custom_config means the provider was created via CustomModal.
|
||||
case LLMProviderName.OPENAI:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<OpenAIModal {...props} />
|
||||
);
|
||||
// "openai" as a provider name can be used for litellm proxy / any OpenAI-compatible provider
|
||||
if (detectIfRealOpenAIProvider(provider)) {
|
||||
return <OpenAIModal {...props} />;
|
||||
} else {
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<AnthropicModal {...props} />
|
||||
);
|
||||
case LLMProviderName.AZURE:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<AzureModal {...props} />
|
||||
);
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return hasCustomConfig ? (
|
||||
<CustomModal {...props} />
|
||||
) : (
|
||||
<OpenRouterModal {...props} />
|
||||
);
|
||||
|
||||
// These providers legitimately store settings in custom_config,
|
||||
// so always use their dedicated modals.
|
||||
return <AnthropicModal {...props} />;
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaModal {...props} />;
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureModal {...props} />;
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIModal {...props} />;
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockModal {...props} />;
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterModal {...props} />;
|
||||
case LLMProviderName.LM_STUDIO:
|
||||
return <LMStudioModal {...props} />;
|
||||
return <LMStudioForm {...props} />;
|
||||
case LLMProviderName.LITELLM_PROXY:
|
||||
return <LiteLLMProxyModal {...props} />;
|
||||
case LLMProviderName.BIFROST:
|
||||
return <BifrostModal {...props} />;
|
||||
case LLMProviderName.OPENAI_COMPATIBLE:
|
||||
return <OpenAICompatibleModal {...props} />;
|
||||
default:
|
||||
return <CustomModal {...props} />;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import React, { useEffect, useRef, useState } from "react";
|
||||
import { Formik, Form, useFormikContext } from "formik";
|
||||
import type { FormikConfig } from "formik";
|
||||
import { ReactNode, useState } from "react";
|
||||
import { Form, FormikProps } from "formik";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { LLMProviderView, ModelConfiguration } from "@/interfaces/llm";
|
||||
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import InputTypeInField from "@/refresh-components/form/InputTypeInField";
|
||||
@@ -16,10 +15,12 @@ import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
|
||||
import Switch from "@/refresh-components/inputs/Switch";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button, LineItemButton } from "@opal/components";
|
||||
import { Button, LineItemButton, Tag } from "@opal/components";
|
||||
import { BaseLLMFormValues } from "@/sections/modals/llmConfig/utils";
|
||||
import type { RichStr } from "@opal/types";
|
||||
import { WithoutStyles } from "@opal/types";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Hoverable } from "@opal/core";
|
||||
import { Content } from "@opal/layouts";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
@@ -47,14 +48,27 @@ import {
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
|
||||
export function FieldSeparator() {
|
||||
return <Separator noPadding className="px-2" />;
|
||||
}
|
||||
|
||||
export type FieldWrapperProps = WithoutStyles<
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>;
|
||||
|
||||
export function FieldWrapper(props: FieldWrapperProps) {
|
||||
return <div {...props} className="p-2 w-full" />;
|
||||
}
|
||||
|
||||
// ─── DisplayNameField ────────────────────────────────────────────────────────
|
||||
|
||||
export interface DisplayNameFieldProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
return (
|
||||
<InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="name"
|
||||
title="Display Name"
|
||||
@@ -66,7 +80,7 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
variant={disabled ? "disabled" : undefined}
|
||||
/>
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
</FieldWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -75,56 +89,47 @@ export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
|
||||
export interface APIKeyFieldProps {
|
||||
optional?: boolean;
|
||||
providerName?: string;
|
||||
subDescription?: string | RichStr;
|
||||
}
|
||||
|
||||
export function APIKeyField({
|
||||
optional = false,
|
||||
providerName,
|
||||
subDescription,
|
||||
}: APIKeyFieldProps) {
|
||||
return (
|
||||
<InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Vertical
|
||||
name="api_key"
|
||||
title="API Key"
|
||||
subDescription={
|
||||
subDescription
|
||||
? subDescription
|
||||
: providerName
|
||||
? `Paste your API key from ${providerName} to access your models.`
|
||||
: "Paste your API key to access your models."
|
||||
providerName
|
||||
? `Paste your API key from ${providerName} to access your models.`
|
||||
: "Paste your API key to access your models."
|
||||
}
|
||||
suffix={optional ? "optional" : undefined}
|
||||
>
|
||||
<PasswordInputTypeInField name="api_key" />
|
||||
<PasswordInputTypeInField name="api_key" placeholder="API Key" />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
</FieldWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── APIBaseField ───────────────────────────────────────────────────────────
|
||||
// ─── SingleDefaultModelField ─────────────────────────────────────────────────
|
||||
|
||||
export interface APIBaseFieldProps {
|
||||
optional?: boolean;
|
||||
subDescription?: string | RichStr;
|
||||
export interface SingleDefaultModelFieldProps {
|
||||
placeholder?: string;
|
||||
}
|
||||
export function APIBaseField({
|
||||
optional = false,
|
||||
subDescription,
|
||||
placeholder = "https://",
|
||||
}: APIBaseFieldProps) {
|
||||
|
||||
export function SingleDefaultModelField({
|
||||
placeholder = "E.g. gpt-4o",
|
||||
}: SingleDefaultModelFieldProps) {
|
||||
return (
|
||||
<InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="api_base"
|
||||
title="API Base URL"
|
||||
subDescription={subDescription}
|
||||
suffix={optional ? "optional" : undefined}
|
||||
>
|
||||
<InputTypeInField name="api_base" placeholder={placeholder} />
|
||||
</InputLayouts.Vertical>
|
||||
</InputLayouts.FieldPadder>
|
||||
<InputLayouts.Vertical
|
||||
name="default_model_name"
|
||||
title="Default Model"
|
||||
description="The model to use by default for this provider unless otherwise specified."
|
||||
>
|
||||
<InputTypeInField name="default_model_name" placeholder={placeholder} />
|
||||
</InputLayouts.Vertical>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -134,8 +139,13 @@ export function APIBaseField({
|
||||
const GROUP_PREFIX = "group:";
|
||||
const AGENT_PREFIX = "agent:";
|
||||
|
||||
export function ModelAccessField() {
|
||||
const formikProps = useFormikContext<BaseLLMFormValues>();
|
||||
interface ModelsAccessFieldProps<T> {
|
||||
formikProps: FormikProps<T>;
|
||||
}
|
||||
|
||||
export function ModelsAccessField<T extends BaseLLMFormValues>({
|
||||
formikProps,
|
||||
}: ModelsAccessFieldProps<T>) {
|
||||
const { agents } = useAgents();
|
||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
||||
const { data: usersData } = useUsers({ includeApiKeys: false });
|
||||
@@ -218,7 +228,7 @@ export function ModelAccessField() {
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-full">
|
||||
<InputLayouts.FieldPadder>
|
||||
<FieldWrapper>
|
||||
<InputLayouts.Horizontal
|
||||
name="is_public"
|
||||
title="Models Access"
|
||||
@@ -239,7 +249,7 @@ export function ModelAccessField() {
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</InputLayouts.Horizontal>
|
||||
</InputLayouts.FieldPadder>
|
||||
</FieldWrapper>
|
||||
|
||||
{!isPublic && (
|
||||
<Card background="light" border="none" padding="sm">
|
||||
@@ -305,7 +315,7 @@ export function ModelAccessField() {
|
||||
</div>
|
||||
)}
|
||||
|
||||
<InputLayouts.FieldSeparator />
|
||||
<FieldSeparator />
|
||||
|
||||
{selectedAgentIds.length > 0 ? (
|
||||
<div className="grid grid-cols-2 gap-1 w-full">
|
||||
@@ -358,115 +368,90 @@ export function ModelAccessField() {
|
||||
);
|
||||
}
|
||||
|
||||
// ─── RefetchButton ──────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Manages an AbortController so that clicking the button cancels any
|
||||
* in-flight fetch before starting a new one. Also aborts on unmount.
|
||||
*/
|
||||
interface RefetchButtonProps {
|
||||
onRefetch: (signal: AbortSignal) => Promise<void> | void;
|
||||
}
|
||||
function RefetchButton({ onRefetch }: RefetchButtonProps) {
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const [isFetching, setIsFetching] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
return () => abortRef.current?.abort();
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={isFetching ? SimpleLoader : SvgRefreshCw}
|
||||
onClick={async () => {
|
||||
abortRef.current?.abort();
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
setIsFetching(true);
|
||||
try {
|
||||
await onRefetch(controller.signal);
|
||||
} catch (err) {
|
||||
if (err instanceof DOMException && err.name === "AbortError") return;
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to fetch models"
|
||||
);
|
||||
} finally {
|
||||
if (!controller.signal.aborted) {
|
||||
setIsFetching(false);
|
||||
}
|
||||
}
|
||||
}}
|
||||
disabled={isFetching}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── ModelsField ─────────────────────────────────────────────────────
|
||||
|
||||
export interface ModelSelectionFieldProps {
|
||||
export interface ModelsFieldProps<T> {
|
||||
formikProps: FormikProps<T>;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
recommendedDefaultModel: SimpleKnownModel | null;
|
||||
shouldShowAutoUpdateToggle: boolean;
|
||||
onRefetch?: (signal: AbortSignal) => Promise<void> | void;
|
||||
/** Called when the user clicks the refresh button to re-fetch models. */
|
||||
onRefetch?: () => Promise<void> | void;
|
||||
/** Called when the user adds a custom model by name. Enables the "Add Model" input. */
|
||||
onAddModel?: (modelName: string) => void;
|
||||
}
|
||||
export function ModelSelectionField({
|
||||
|
||||
export function ModelsField<T extends BaseLLMFormValues>({
|
||||
formikProps,
|
||||
modelConfigurations,
|
||||
recommendedDefaultModel,
|
||||
shouldShowAutoUpdateToggle,
|
||||
onRefetch,
|
||||
onAddModel,
|
||||
}: ModelSelectionFieldProps) {
|
||||
const formikProps = useFormikContext<BaseLLMFormValues>();
|
||||
}: ModelsFieldProps<T>) {
|
||||
const [newModelName, setNewModelName] = useState("");
|
||||
const isAutoMode = formikProps.values.is_auto_mode;
|
||||
const models = formikProps.values.model_configurations;
|
||||
const selectedModels = formikProps.values.selected_model_names ?? [];
|
||||
const defaultModel = formikProps.values.default_model_name;
|
||||
|
||||
// Snapshot the original model visibility so we can restore it when
|
||||
// toggling auto mode back on.
|
||||
const originalModelsRef = useRef(models);
|
||||
useEffect(() => {
|
||||
if (originalModelsRef.current.length === 0 && models.length > 0) {
|
||||
originalModelsRef.current = models;
|
||||
function handleCheckboxChange(modelName: string, checked: boolean) {
|
||||
// Read current values inside the handler to avoid stale closure issues
|
||||
const currentSelected = formikProps.values.selected_model_names ?? [];
|
||||
const currentDefault = formikProps.values.default_model_name;
|
||||
|
||||
if (checked) {
|
||||
const newSelected = [...currentSelected, modelName];
|
||||
formikProps.setFieldValue("selected_model_names", newSelected);
|
||||
// If this is the first model, set it as default
|
||||
if (currentSelected.length === 0) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
}
|
||||
} else {
|
||||
const newSelected = currentSelected.filter((name) => name !== modelName);
|
||||
formikProps.setFieldValue("selected_model_names", newSelected);
|
||||
// If removing the default, set the first remaining model as default
|
||||
if (currentDefault === modelName && newSelected.length > 0) {
|
||||
formikProps.setFieldValue("default_model_name", newSelected[0]);
|
||||
} else if (newSelected.length === 0) {
|
||||
formikProps.setFieldValue("default_model_name", undefined);
|
||||
}
|
||||
}
|
||||
}, [models]);
|
||||
}
|
||||
|
||||
// Automatically derive test_model_name from model_configurations.
|
||||
// Any change to visibility or the model list syncs this automatically.
|
||||
useEffect(() => {
|
||||
const firstVisible = models.find((m) => m.is_visible)?.name;
|
||||
if (firstVisible !== formikProps.values.test_model_name) {
|
||||
formikProps.setFieldValue("test_model_name", firstVisible);
|
||||
}
|
||||
}, [models]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
function setVisibility(modelName: string, visible: boolean) {
|
||||
const updated = models.map((m) =>
|
||||
m.name === modelName ? { ...m, is_visible: visible } : m
|
||||
);
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
function handleSetDefault(modelName: string) {
|
||||
formikProps.setFieldValue("default_model_name", modelName);
|
||||
}
|
||||
|
||||
function handleToggleAutoMode(nextIsAutoMode: boolean) {
|
||||
formikProps.setFieldValue("is_auto_mode", nextIsAutoMode);
|
||||
if (nextIsAutoMode) {
|
||||
formikProps.setFieldValue(
|
||||
"model_configurations",
|
||||
originalModelsRef.current
|
||||
);
|
||||
formikProps.setFieldValue(
|
||||
"selected_model_names",
|
||||
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
|
||||
);
|
||||
formikProps.setFieldValue(
|
||||
"default_model_name",
|
||||
recommendedDefaultModel?.name ?? undefined
|
||||
);
|
||||
}
|
||||
|
||||
const allSelected =
|
||||
modelConfigurations.length > 0 &&
|
||||
modelConfigurations.every((m) => selectedModels.includes(m.name));
|
||||
|
||||
function handleToggleSelectAll() {
|
||||
if (allSelected) {
|
||||
formikProps.setFieldValue("selected_model_names", []);
|
||||
formikProps.setFieldValue("default_model_name", undefined);
|
||||
} else {
|
||||
const allNames = modelConfigurations.map((m) => m.name);
|
||||
formikProps.setFieldValue("selected_model_names", allNames);
|
||||
if (!formikProps.values.default_model_name && allNames.length > 0) {
|
||||
formikProps.setFieldValue("default_model_name", allNames[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const allSelected = models.length > 0 && models.every((m) => m.is_visible);
|
||||
|
||||
function handleToggleSelectAll() {
|
||||
const nextVisible = !allSelected;
|
||||
const updated = models.map((m) => ({
|
||||
...m,
|
||||
is_visible: nextVisible,
|
||||
}));
|
||||
formikProps.setFieldValue("model_configurations", updated);
|
||||
}
|
||||
|
||||
const visibleModels = models.filter((m) => m.is_visible);
|
||||
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
|
||||
|
||||
return (
|
||||
<Card background="light" border="none" padding="sm">
|
||||
@@ -479,45 +464,118 @@ export function ModelSelectionField({
|
||||
>
|
||||
<Section flexDirection="row" gap={0}>
|
||||
<Button
|
||||
disabled={isAutoMode || models.length === 0}
|
||||
disabled={isAutoMode || modelConfigurations.length === 0}
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
onClick={handleToggleSelectAll}
|
||||
>
|
||||
{allSelected ? "Unselect All" : "Select All"}
|
||||
</Button>
|
||||
{onRefetch && <RefetchButton onRefetch={onRefetch} />}
|
||||
{onRefetch && (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgRefreshCw}
|
||||
onClick={async () => {
|
||||
try {
|
||||
await onRefetch();
|
||||
} catch (err) {
|
||||
toast.error(
|
||||
err instanceof Error
|
||||
? err.message
|
||||
: "Failed to fetch models"
|
||||
);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
{models.length === 0 ? (
|
||||
{modelConfigurations.length === 0 ? (
|
||||
<EmptyMessageCard title="No models available." padding="sm" />
|
||||
) : (
|
||||
<Section gap={0.25}>
|
||||
{isAutoMode
|
||||
? visibleModels.map((model) => (
|
||||
<LineItemButton
|
||||
? // Auto mode: read-only display
|
||||
visibleModels.map((model) => (
|
||||
<Hoverable.Root
|
||||
key={model.name}
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state="selected"
|
||||
icon={() => <Checkbox checked />}
|
||||
title={model.display_name || model.name}
|
||||
/>
|
||||
group="LLMConfigurationButton"
|
||||
widthVariant="full"
|
||||
>
|
||||
<LineItemButton
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state="selected"
|
||||
icon={() => <Checkbox checked />}
|
||||
title={model.display_name || model.name}
|
||||
rightChildren={
|
||||
model.name === defaultModel ? (
|
||||
<Section>
|
||||
<Tag title="Default Model" color="blue" />
|
||||
</Section>
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</Hoverable.Root>
|
||||
))
|
||||
: models.map((model) => (
|
||||
<LineItemButton
|
||||
key={model.name}
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state={model.is_visible ? "selected" : "empty"}
|
||||
icon={() => <Checkbox checked={model.is_visible} />}
|
||||
title={model.name}
|
||||
onClick={() => setVisibility(model.name, !model.is_visible)}
|
||||
/>
|
||||
))}
|
||||
: // Manual mode: checkbox selection
|
||||
modelConfigurations.map((modelConfiguration) => {
|
||||
const isSelected = selectedModels.includes(
|
||||
modelConfiguration.name
|
||||
);
|
||||
const isDefault = defaultModel === modelConfiguration.name;
|
||||
|
||||
return (
|
||||
<Hoverable.Root
|
||||
key={modelConfiguration.name}
|
||||
group="LLMConfigurationButton"
|
||||
widthVariant="full"
|
||||
>
|
||||
<LineItemButton
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
selectVariant="select-heavy"
|
||||
state={isSelected ? "selected" : "empty"}
|
||||
icon={() => <Checkbox checked={isSelected} />}
|
||||
title={modelConfiguration.name}
|
||||
onClick={() =>
|
||||
handleCheckboxChange(
|
||||
modelConfiguration.name,
|
||||
!isSelected
|
||||
)
|
||||
}
|
||||
rightChildren={
|
||||
isSelected ? (
|
||||
isDefault ? (
|
||||
<Section>
|
||||
<Tag color="blue" title="Default Model" />
|
||||
</Section>
|
||||
) : (
|
||||
<Hoverable.Item
|
||||
group="LLMConfigurationButton"
|
||||
variant="opacity-on-hover"
|
||||
>
|
||||
<Button
|
||||
size="sm"
|
||||
prominence="internal"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleSetDefault(modelConfiguration.name);
|
||||
}}
|
||||
type="button"
|
||||
>
|
||||
Set as default
|
||||
</Button>
|
||||
</Hoverable.Item>
|
||||
)
|
||||
) : undefined
|
||||
}
|
||||
/>
|
||||
</Hoverable.Root>
|
||||
);
|
||||
})}
|
||||
</Section>
|
||||
)}
|
||||
|
||||
@@ -532,7 +590,7 @@ export function ModelSelectionField({
|
||||
if (e.key === "Enter" && newModelName.trim()) {
|
||||
e.preventDefault();
|
||||
const trimmed = newModelName.trim();
|
||||
if (!models.some((m) => m.name === trimmed)) {
|
||||
if (!modelConfigurations.some((m) => m.name === trimmed)) {
|
||||
onAddModel(trimmed);
|
||||
setNewModelName("");
|
||||
}
|
||||
@@ -547,11 +605,14 @@ export function ModelSelectionField({
|
||||
type="button"
|
||||
disabled={
|
||||
!newModelName.trim() ||
|
||||
models.some((m) => m.name === newModelName.trim())
|
||||
modelConfigurations.some((m) => m.name === newModelName.trim())
|
||||
}
|
||||
onClick={() => {
|
||||
const trimmed = newModelName.trim();
|
||||
if (trimmed && !models.some((m) => m.name === trimmed)) {
|
||||
if (
|
||||
trimmed &&
|
||||
!modelConfigurations.some((m) => m.name === trimmed)
|
||||
) {
|
||||
onAddModel(trimmed);
|
||||
setNewModelName("");
|
||||
}
|
||||
@@ -578,87 +639,41 @@ export function ModelSelectionField({
|
||||
);
|
||||
}
|
||||
|
||||
// ─── ModalWrapper ─────────────────────────────────────────────────────
|
||||
// ============================================================================
|
||||
// LLMConfigurationModalWrapper
|
||||
// ============================================================================
|
||||
|
||||
export interface ModalWrapperProps<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
llmProvider?: LLMProviderView;
|
||||
interface LLMConfigurationModalWrapperProps {
|
||||
providerEndpoint: string;
|
||||
providerName?: string;
|
||||
existingProviderName?: string;
|
||||
onClose: () => void;
|
||||
initialValues: T;
|
||||
validationSchema: FormikConfig<T>["validationSchema"];
|
||||
onSubmit: FormikConfig<T>["onSubmit"];
|
||||
children: React.ReactNode;
|
||||
isFormValid: boolean;
|
||||
isDirty?: boolean;
|
||||
isTesting?: boolean;
|
||||
isSubmitting?: boolean;
|
||||
children: ReactNode;
|
||||
}
|
||||
export function ModalWrapper<T extends BaseLLMFormValues = BaseLLMFormValues>({
|
||||
|
||||
export function LLMConfigurationModalWrapper({
|
||||
providerEndpoint,
|
||||
providerName,
|
||||
llmProvider,
|
||||
existingProviderName,
|
||||
onClose,
|
||||
initialValues,
|
||||
validationSchema,
|
||||
onSubmit,
|
||||
isFormValid,
|
||||
isDirty,
|
||||
isTesting,
|
||||
isSubmitting,
|
||||
children,
|
||||
}: ModalWrapperProps<T>) {
|
||||
return (
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
validateOnMount
|
||||
onSubmit={onSubmit}
|
||||
>
|
||||
{() => (
|
||||
<ModalWrapperInner
|
||||
providerName={providerName}
|
||||
llmProvider={llmProvider}
|
||||
onClose={onClose}
|
||||
modelConfigurations={initialValues.model_configurations}
|
||||
>
|
||||
{children}
|
||||
</ModalWrapperInner>
|
||||
)}
|
||||
</Formik>
|
||||
);
|
||||
}
|
||||
|
||||
interface ModalWrapperInnerProps {
|
||||
providerName: string;
|
||||
llmProvider?: LLMProviderView;
|
||||
onClose: () => void;
|
||||
modelConfigurations?: ModelConfiguration[];
|
||||
children: React.ReactNode;
|
||||
}
|
||||
function ModalWrapperInner({
|
||||
providerName,
|
||||
llmProvider,
|
||||
onClose,
|
||||
modelConfigurations,
|
||||
children,
|
||||
}: ModalWrapperInnerProps) {
|
||||
const { isValid, dirty, isSubmitting, status, setFieldValue, values } =
|
||||
useFormikContext<BaseLLMFormValues>();
|
||||
|
||||
// When SWR resolves after mount, populate model_configurations if still
|
||||
// empty. test_model_name is then derived automatically by
|
||||
// ModelSelectionField's useEffect.
|
||||
useEffect(() => {
|
||||
if (
|
||||
modelConfigurations &&
|
||||
modelConfigurations.length > 0 &&
|
||||
values.model_configurations.length === 0
|
||||
) {
|
||||
setFieldValue("model_configurations", modelConfigurations);
|
||||
}
|
||||
}, [modelConfigurations]); // eslint-disable-line react-hooks/exhaustive-deps
|
||||
|
||||
const isTesting = status?.isTesting === true;
|
||||
}: LLMConfigurationModalWrapperProps) {
|
||||
const busy = isTesting || isSubmitting;
|
||||
const providerIcon = getProviderIcon(providerName);
|
||||
const providerDisplayName = getProviderDisplayName(providerName);
|
||||
const providerProductName = getProviderProductName(providerName);
|
||||
const providerIcon = getProviderIcon(providerEndpoint);
|
||||
const providerDisplayName =
|
||||
providerName ?? getProviderDisplayName(providerEndpoint);
|
||||
const providerProductName = getProviderProductName(providerEndpoint);
|
||||
|
||||
const title = llmProvider
|
||||
? `Configure "${llmProvider.name}"`
|
||||
const title = existingProviderName
|
||||
? `Configure "${existingProviderName}"`
|
||||
: `Set up ${providerProductName}`;
|
||||
const description = `Connect to ${providerDisplayName} and set up your ${providerProductName} models.`;
|
||||
|
||||
@@ -674,7 +689,7 @@ function ModalWrapperInner({
|
||||
description={description}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body padding={0.5} gap={0}>
|
||||
<Modal.Body padding={0.5} gap={0.5}>
|
||||
{children}
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
@@ -682,11 +697,13 @@ function ModalWrapperInner({
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
disabled={!isValid || !dirty || busy}
|
||||
disabled={
|
||||
!isFormValid || busy || (!!existingProviderName && !isDirty)
|
||||
}
|
||||
type="submit"
|
||||
icon={busy ? SimpleLoader : undefined}
|
||||
>
|
||||
{llmProvider?.name
|
||||
{existingProviderName
|
||||
? busy
|
||||
? "Updating"
|
||||
: "Update"
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
import { LLMProviderName, LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { refreshLlmProviderCaches } from "@/lib/llmConfig/cache";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { parseAzureTargetUri } from "@/lib/azureTargetUri";
|
||||
@@ -13,11 +18,13 @@ import {
|
||||
} from "@/lib/analytics";
|
||||
import {
|
||||
BaseLLMFormValues,
|
||||
SubmitLLMProviderParams,
|
||||
SubmitOnboardingProviderParams,
|
||||
TestApiKeyResult,
|
||||
filterModelConfigurations,
|
||||
getAutoModeModelConfigurations,
|
||||
} from "@/sections/modals/llmConfig/utils";
|
||||
|
||||
// ─── Test helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
const submitLlmTestRequest = async (
|
||||
payload: Record<string, unknown>,
|
||||
fallbackErrorMessage: string
|
||||
@@ -43,6 +50,161 @@ const submitLlmTestRequest = async (
|
||||
}
|
||||
};
|
||||
|
||||
export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
providerName,
|
||||
values,
|
||||
initialValues,
|
||||
modelConfigurations,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
hideSuccess,
|
||||
setIsTesting,
|
||||
mutate,
|
||||
onClose,
|
||||
setSubmitting,
|
||||
}: SubmitLLMProviderParams<T>): Promise<void> => {
|
||||
setSubmitting(true);
|
||||
|
||||
const { selected_model_names: visibleModels, api_key, ...rest } = values;
|
||||
|
||||
// In auto mode, use recommended models from descriptor
|
||||
// In manual mode, use user's selection
|
||||
let filteredModelConfigurations: ModelConfiguration[];
|
||||
let finalDefaultModelName = rest.default_model_name;
|
||||
|
||||
if (values.is_auto_mode) {
|
||||
filteredModelConfigurations =
|
||||
getAutoModeModelConfigurations(modelConfigurations);
|
||||
|
||||
// In auto mode, use the first recommended model as default if current default isn't in the list
|
||||
const visibleModelNames = new Set(
|
||||
filteredModelConfigurations.map((m) => m.name)
|
||||
);
|
||||
if (
|
||||
finalDefaultModelName &&
|
||||
!visibleModelNames.has(finalDefaultModelName)
|
||||
) {
|
||||
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
|
||||
}
|
||||
} else {
|
||||
filteredModelConfigurations = filterModelConfigurations(
|
||||
modelConfigurations,
|
||||
visibleModels,
|
||||
rest.default_model_name as string | undefined
|
||||
);
|
||||
}
|
||||
|
||||
const customConfigChanged = !isEqual(
|
||||
values.custom_config,
|
||||
initialValues.custom_config
|
||||
);
|
||||
|
||||
const normalizedApiBase =
|
||||
typeof rest.api_base === "string" && rest.api_base.trim() === ""
|
||||
? undefined
|
||||
: rest.api_base;
|
||||
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: normalizedApiBase,
|
||||
default_model_name: finalDefaultModelName,
|
||||
api_key,
|
||||
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
|
||||
custom_config_changed: customConfigChanged,
|
||||
model_configurations: filteredModelConfigurations,
|
||||
};
|
||||
|
||||
// Test the configuration
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setIsTesting(true);
|
||||
|
||||
const response = await fetch("/api/admin/llm/test", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: finalDefaultModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
});
|
||||
setIsTesting(false);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(errorMsg);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
}`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
const fullErrorMsg = existingLlmProvider
|
||||
? `Failed to update provider: ${errorMsg}`
|
||||
: `Failed to enable provider: ${errorMsg}`;
|
||||
toast.error(fullErrorMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
if (shouldMarkAsDefault) {
|
||||
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: finalDefaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const errorMsg = (await setDefaultResponse.json()).detail;
|
||||
toast.error(`Failed to set provider as default: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
await refreshLlmProviderCaches(mutate);
|
||||
onClose();
|
||||
|
||||
if (!hideSuccess) {
|
||||
const successMsg = existingLlmProvider
|
||||
? "Provider updated successfully!"
|
||||
: "Provider enabled successfully!";
|
||||
toast.success(successMsg);
|
||||
}
|
||||
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
});
|
||||
|
||||
setSubmitting(false);
|
||||
};
|
||||
|
||||
export const testApiKeyHelper = async (
|
||||
providerName: string,
|
||||
formValues: Record<string, unknown>,
|
||||
@@ -79,7 +241,7 @@ export const testApiKeyHelper = async (
|
||||
...((formValues?.custom_config as Record<string, unknown>) ?? {}),
|
||||
...(customConfigOverride ?? {}),
|
||||
},
|
||||
model: modelName ?? (formValues?.test_model_name as string) ?? "",
|
||||
model: modelName ?? (formValues?.default_model_name as string) ?? "",
|
||||
};
|
||||
|
||||
return await submitLlmTestRequest(
|
||||
@@ -97,148 +259,96 @@ export const testCustomProvider = async (
|
||||
);
|
||||
};
|
||||
|
||||
// ─── Submit provider ──────────────────────────────────────────────────────
|
||||
|
||||
export interface SubmitProviderParams<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
values: T;
|
||||
initialValues: T;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
isCustomProvider?: boolean;
|
||||
setStatus: (status: Record<string, unknown>) => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
onClose: () => void;
|
||||
/** Called after successful create/update + set-default. Use for cache refresh, state updates, toasts, etc. */
|
||||
onSuccess?: () => void | Promise<void>;
|
||||
/** Analytics source for tracking. @default LLMProviderConfiguredSource.ADMIN_PAGE */
|
||||
analyticsSource?: LLMProviderConfiguredSource;
|
||||
}
|
||||
|
||||
export async function submitProvider<T extends BaseLLMFormValues>({
|
||||
export const submitOnboardingProvider = async ({
|
||||
providerName,
|
||||
values,
|
||||
initialValues,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
payload,
|
||||
onboardingState,
|
||||
onboardingActions,
|
||||
isCustomProvider,
|
||||
setStatus,
|
||||
setSubmitting,
|
||||
onClose,
|
||||
onSuccess,
|
||||
analyticsSource = LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
}: SubmitProviderParams<T>): Promise<void> {
|
||||
setSubmitting(true);
|
||||
setIsSubmitting,
|
||||
}: SubmitOnboardingProviderParams): Promise<void> => {
|
||||
setIsSubmitting(true);
|
||||
|
||||
const { test_model_name, api_key, ...rest } = values;
|
||||
const testModelName =
|
||||
test_model_name ||
|
||||
values.model_configurations.find((m) => m.is_visible)?.name ||
|
||||
"";
|
||||
|
||||
// ── Test credentials ────────────────────────────────────────────────
|
||||
const customConfigChanged = !isEqual(
|
||||
values.custom_config,
|
||||
initialValues.custom_config
|
||||
);
|
||||
|
||||
const normalizedApiBase =
|
||||
typeof rest.api_base === "string" && rest.api_base.trim() === ""
|
||||
? undefined
|
||||
: rest.api_base;
|
||||
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_base: normalizedApiBase,
|
||||
api_key,
|
||||
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
|
||||
custom_config_changed: customConfigChanged,
|
||||
};
|
||||
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setStatus({ isTesting: true });
|
||||
|
||||
const testResult = await submitLlmTestRequest(
|
||||
{
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: testModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
},
|
||||
"An error occurred while testing the provider."
|
||||
);
|
||||
setStatus({ isTesting: false });
|
||||
|
||||
if (!testResult.ok) {
|
||||
toast.error(testResult.errorMessage);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
// Test credentials
|
||||
let result: TestApiKeyResult;
|
||||
if (isCustomProvider) {
|
||||
result = await testCustomProvider(payload);
|
||||
} else {
|
||||
result = await testApiKeyHelper(providerName, payload);
|
||||
}
|
||||
|
||||
// ── Create/update provider ──────────────────────────────────────────
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
}`,
|
||||
{
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
const fullErrorMsg = existingLlmProvider
|
||||
? `Failed to update provider: ${errorMsg}`
|
||||
: `Failed to enable provider: ${errorMsg}`;
|
||||
toast.error(fullErrorMsg);
|
||||
setSubmitting(false);
|
||||
if (!result.ok) {
|
||||
toast.error(result.errorMessage);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// ── Set as default ──────────────────────────────────────────────────
|
||||
if (shouldMarkAsDefault && testModelName) {
|
||||
// Create provider
|
||||
const response = await fetch(`${LLM_PROVIDERS_ADMIN_URL}?is_creation=true`, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(errorMsg);
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
// Set as default if first provider
|
||||
if (
|
||||
onboardingState?.data?.llmProviders == null ||
|
||||
onboardingState.data.llmProviders.length === 0
|
||||
) {
|
||||
try {
|
||||
const newLlmProvider = await response.json();
|
||||
if (newLlmProvider?.id != null) {
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: testModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const err = await setDefaultResponse.json().catch(() => ({}));
|
||||
toast.error(err?.detail ?? "Failed to set provider as default");
|
||||
setSubmitting(false);
|
||||
return;
|
||||
const defaultModelName =
|
||||
(payload as Record<string, string>).default_model_name ??
|
||||
(payload as Record<string, ModelConfiguration[]>)
|
||||
.model_configurations?.[0]?.name ??
|
||||
"";
|
||||
|
||||
if (defaultModelName) {
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: defaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const err = await setDefaultResponse.json().catch(() => ({}));
|
||||
toast.error(err?.detail ?? "Failed to set provider as default");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
} catch (_e) {
|
||||
toast.error("Failed to set new provider as default");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Post-success ────────────────────────────────────────────────────
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: analyticsSource,
|
||||
provider: isCustomProvider ? "custom" : providerName,
|
||||
is_creation: true,
|
||||
source: LLMProviderConfiguredSource.CHAT_ONBOARDING,
|
||||
});
|
||||
|
||||
if (onSuccess) await onSuccess();
|
||||
// Update onboarding state
|
||||
onboardingActions.updateData({
|
||||
llmProviders: [
|
||||
...(onboardingState?.data.llmProviders ?? []),
|
||||
isCustomProvider ? "custom" : providerName,
|
||||
],
|
||||
});
|
||||
onboardingActions.setButtonActive(true);
|
||||
|
||||
setSubmitting(false);
|
||||
setIsSubmitting(false);
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,130 +1,197 @@
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { useWellKnownLLMProvider } from "@/hooks/useLLMProviders";
|
||||
import { ScopedMutator } from "swr";
|
||||
import { OnboardingActions, OnboardingState } from "@/interfaces/onboarding";
|
||||
|
||||
// ─── useInitialValues ─────────────────────────────────────────────────────
|
||||
// Common class names for the Form component across all LLM provider forms
|
||||
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
|
||||
|
||||
/** Builds the merged model list from existing + well-known, deduped by name. */
|
||||
function buildModelConfigurations(
|
||||
export const buildDefaultInitialValues = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
|
||||
): ModelConfiguration[] {
|
||||
const existingModels = existingLlmProvider?.model_configurations ?? [];
|
||||
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
|
||||
modelConfigurations?: ModelConfiguration[],
|
||||
currentDefaultModelName?: string
|
||||
) => {
|
||||
const defaultModelName =
|
||||
(currentDefaultModelName &&
|
||||
existingLlmProvider?.model_configurations?.some(
|
||||
(m) => m.name === currentDefaultModelName
|
||||
)
|
||||
? currentDefaultModelName
|
||||
: undefined) ??
|
||||
existingLlmProvider?.model_configurations?.[0]?.name ??
|
||||
modelConfigurations?.[0]?.name ??
|
||||
"";
|
||||
|
||||
const modelMap = new Map<string, ModelConfiguration>();
|
||||
wellKnownModels.forEach((m) => modelMap.set(m.name, m));
|
||||
existingModels.forEach((m) => modelMap.set(m.name, m));
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
}
|
||||
|
||||
/** Shared initial values for all LLM provider forms (both onboarding and admin). */
|
||||
export function useInitialValues(
|
||||
isOnboarding: boolean,
|
||||
providerName: LLMProviderName,
|
||||
existingLlmProvider?: LLMProviderView
|
||||
) {
|
||||
const { wellKnownLLMProvider } = useWellKnownLLMProvider(providerName);
|
||||
|
||||
const modelConfigurations = buildModelConfigurations(
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider ?? undefined
|
||||
);
|
||||
|
||||
const testModelName =
|
||||
modelConfigurations.find((m) => m.is_visible)?.name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name;
|
||||
// Auto mode must be explicitly enabled by the user
|
||||
// Default to false for new providers, preserve existing value when editing
|
||||
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;
|
||||
|
||||
return {
|
||||
provider: existingLlmProvider?.provider ?? providerName,
|
||||
name: isOnboarding ? providerName : existingLlmProvider?.name ?? "",
|
||||
api_key: existingLlmProvider?.api_key ?? undefined,
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
name: existingLlmProvider?.name || "",
|
||||
default_model_name: defaultModelName,
|
||||
is_public: existingLlmProvider?.is_public ?? true,
|
||||
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
|
||||
is_auto_mode: isAutoMode,
|
||||
groups: existingLlmProvider?.groups ?? [],
|
||||
personas: existingLlmProvider?.personas ?? [],
|
||||
model_configurations: modelConfigurations,
|
||||
test_model_name: testModelName,
|
||||
selected_model_names: existingLlmProvider
|
||||
? existingLlmProvider.model_configurations
|
||||
.filter((modelConfiguration) => modelConfiguration.is_visible)
|
||||
.map((modelConfiguration) => modelConfiguration.name)
|
||||
: modelConfigurations
|
||||
?.filter((modelConfiguration) => modelConfiguration.is_visible)
|
||||
.map((modelConfiguration) => modelConfiguration.name) ?? [],
|
||||
};
|
||||
}
|
||||
|
||||
// ─── buildValidationSchema ────────────────────────────────────────────────
|
||||
|
||||
interface ValidationSchemaOptions {
|
||||
apiKey?: boolean;
|
||||
apiBase?: boolean;
|
||||
extra?: Yup.ObjectShape;
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the validation schema for a modal.
|
||||
*
|
||||
* @param isOnboarding — controls the base schema:
|
||||
* - `true`: minimal (only `test_model_name`).
|
||||
* - `false`: full admin schema (display name, access, models, etc.).
|
||||
* @param options.apiKey — require `api_key`.
|
||||
* @param options.apiBase — require `api_base`.
|
||||
* @param options.extra — arbitrary Yup fields for provider-specific validation.
|
||||
*/
|
||||
export function buildValidationSchema(
|
||||
isOnboarding: boolean,
|
||||
{ apiKey, apiBase, extra }: ValidationSchemaOptions = {}
|
||||
) {
|
||||
const providerFields: Yup.ObjectShape = {
|
||||
...(apiKey && {
|
||||
api_key: Yup.string().required("API Key is required"),
|
||||
}),
|
||||
...(apiBase && {
|
||||
api_base: Yup.string().required("API Base URL is required"),
|
||||
}),
|
||||
...extra,
|
||||
};
|
||||
|
||||
if (isOnboarding) {
|
||||
return Yup.object().shape({
|
||||
test_model_name: Yup.string().required("Model name is required"),
|
||||
...providerFields,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
export const buildDefaultValidationSchema = () => {
|
||||
return Yup.object({
|
||||
name: Yup.string().required("Display Name is required"),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
is_public: Yup.boolean().required(),
|
||||
is_auto_mode: Yup.boolean().required(),
|
||||
groups: Yup.array().of(Yup.number()),
|
||||
personas: Yup.array().of(Yup.number()),
|
||||
test_model_name: Yup.string().required("Model name is required"),
|
||||
...providerFields,
|
||||
selected_model_names: Yup.array().of(Yup.string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// ─── Form value types ─────────────────────────────────────────────────────
|
||||
export const buildAvailableModelConfigurations = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
wellKnownLLMProvider?: WellKnownLLMProviderDescriptor
|
||||
): ModelConfiguration[] => {
|
||||
const existingModels = existingLlmProvider?.model_configurations ?? [];
|
||||
const wellKnownModels = wellKnownLLMProvider?.known_models ?? [];
|
||||
|
||||
/** Base form values that all provider forms share. */
|
||||
// Create a map to deduplicate by model name, preferring existing models
|
||||
const modelMap = new Map<string, ModelConfiguration>();
|
||||
|
||||
// Add well-known models first
|
||||
wellKnownModels.forEach((model) => {
|
||||
modelMap.set(model.name, model);
|
||||
});
|
||||
|
||||
// Override with existing models (they take precedence)
|
||||
existingModels.forEach((model) => {
|
||||
modelMap.set(model.name, model);
|
||||
});
|
||||
|
||||
return Array.from(modelMap.values());
|
||||
};
|
||||
|
||||
// Base form values that all provider forms share
|
||||
export interface BaseLLMFormValues {
|
||||
name: string;
|
||||
api_key?: string;
|
||||
api_base?: string;
|
||||
/** Model name used for the test request — automatically derived. */
|
||||
test_model_name?: string;
|
||||
default_model_name?: string;
|
||||
is_public: boolean;
|
||||
is_auto_mode: boolean;
|
||||
groups: number[];
|
||||
personas: number[];
|
||||
/** The full model list with is_visible set directly by user interaction. */
|
||||
model_configurations: ModelConfiguration[];
|
||||
selected_model_names: string[];
|
||||
custom_config?: Record<string, string>;
|
||||
}
|
||||
|
||||
// ─── Misc ─────────────────────────────────────────────────────────────────
|
||||
export interface SubmitLLMProviderParams<
|
||||
T extends BaseLLMFormValues = BaseLLMFormValues,
|
||||
> {
|
||||
providerName: string;
|
||||
values: T;
|
||||
initialValues: T;
|
||||
modelConfigurations: ModelConfiguration[];
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
hideSuccess?: boolean;
|
||||
setIsTesting: (testing: boolean) => void;
|
||||
mutate: ScopedMutator;
|
||||
onClose: () => void;
|
||||
setSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
|
||||
export const filterModelConfigurations = (
|
||||
currentModelConfigurations: ModelConfiguration[],
|
||||
visibleModels: string[],
|
||||
defaultModelName?: string
|
||||
): ModelConfiguration[] => {
|
||||
return currentModelConfigurations
|
||||
.map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
)
|
||||
.filter(
|
||||
(modelConfiguration) =>
|
||||
modelConfiguration.name === defaultModelName ||
|
||||
modelConfiguration.is_visible
|
||||
);
|
||||
};
|
||||
|
||||
// Helper to get model configurations for auto mode
|
||||
// In auto mode, we include ALL models but preserve their visibility status
|
||||
// Models in the auto config are visible, others are created but not visible
|
||||
export const getAutoModeModelConfigurations = (
|
||||
modelConfigurations: ModelConfiguration[]
|
||||
): ModelConfiguration[] => {
|
||||
return modelConfigurations.map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: modelConfiguration.is_visible,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
export type TestApiKeyResult =
|
||||
| { ok: true }
|
||||
| { ok: false; errorMessage: string };
|
||||
|
||||
export const getModelOptions = (
|
||||
fetchedModelConfigurations: Array<{ name: string }>
|
||||
) => {
|
||||
return fetchedModelConfigurations.map((model) => ({
|
||||
label: model.name,
|
||||
value: model.name,
|
||||
}));
|
||||
};
|
||||
|
||||
/** Initial values used by onboarding forms (flat shape, always creating new). */
|
||||
export const buildOnboardingInitialValues = () => ({
|
||||
name: "",
|
||||
provider: "",
|
||||
api_key: "",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
default_model_name: "",
|
||||
model_configurations: [] as ModelConfiguration[],
|
||||
custom_config: {} as Record<string, string>,
|
||||
api_key_changed: true,
|
||||
groups: [] as number[],
|
||||
is_public: true,
|
||||
is_auto_mode: false,
|
||||
personas: [] as number[],
|
||||
selected_model_names: [] as string[],
|
||||
deployment_name: "",
|
||||
target_uri: "",
|
||||
});
|
||||
|
||||
export interface SubmitOnboardingProviderParams {
|
||||
providerName: string;
|
||||
payload: Record<string, unknown>;
|
||||
onboardingState: OnboardingState;
|
||||
onboardingActions: OnboardingActions;
|
||||
isCustomProvider: boolean;
|
||||
onClose: () => void;
|
||||
setIsSubmitting: (submitting: boolean) => void;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user