mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-08 08:22:42 +00:00
Compare commits
105 Commits
v3.1.2
...
refactor/l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcb2e063c5 | ||
|
|
4328e432fe | ||
|
|
939acac4db | ||
|
|
5bdcb2a260 | ||
|
|
87b7ca7aac | ||
|
|
c02ac4f2c5 | ||
|
|
ce2aac7b49 | ||
|
|
98fbb7cd83 | ||
|
|
ccef373f7e | ||
|
|
6fc7e67c4a | ||
|
|
a6a28e7c57 | ||
|
|
d928a685b4 | ||
|
|
ceece8e006 | ||
|
|
d823de9bd8 | ||
|
|
6b2dd45cb9 | ||
|
|
c3568be783 | ||
|
|
6206e69a77 | ||
|
|
04f786b86b | ||
|
|
0ab8a75da8 | ||
|
|
c51ffcdf3d | ||
|
|
c6eb6756bf | ||
|
|
2051153a05 | ||
|
|
b796d0b36a | ||
|
|
299d421de1 | ||
|
|
8b35a0b1a2 | ||
|
|
c6e7c12960 | ||
|
|
8cb6ef58e5 | ||
|
|
f2170e07f4 | ||
|
|
4dc7d871ff | ||
|
|
4a7fe33cd4 | ||
|
|
e54408b365 | ||
|
|
d7a5e541e5 | ||
|
|
a9dece99bd | ||
|
|
e0a6dc5126 | ||
|
|
179a0759d7 | ||
|
|
e665ff96f4 | ||
|
|
a99e275940 | ||
|
|
9ddbbb4e6b | ||
|
|
6533bc61cb | ||
|
|
f31d99a2bd | ||
|
|
80eb544287 | ||
|
|
c85b0095ea | ||
|
|
3e84454cbc | ||
|
|
ca5dff6400 | ||
|
|
72146c27c1 | ||
|
|
0af55a9247 | ||
|
|
677f1749ad | ||
|
|
9d353cb108 | ||
|
|
fc998c1034 | ||
|
|
cc4008a916 | ||
|
|
142c3ac9cf | ||
|
|
3a46b94ca0 | ||
|
|
94633698c3 | ||
|
|
6ae15589cd | ||
|
|
c24a8bb228 | ||
|
|
01945abd86 | ||
|
|
658632195f | ||
|
|
ec6fd01ba4 | ||
|
|
148e6fb97d | ||
|
|
6598c1a48d | ||
|
|
497ce43bd8 | ||
|
|
8634cb0446 | ||
|
|
8d56fd3dc6 | ||
|
|
a7579a99d0 | ||
|
|
3533c10da4 | ||
|
|
7b0414bf0d | ||
|
|
b500ea537a | ||
|
|
abd6d55add | ||
|
|
f15b6b8034 | ||
|
|
fb40485f25 | ||
|
|
22e85f1f28 | ||
|
|
2ef7c3e6f3 | ||
|
|
92a471ed2b | ||
|
|
d1b7e529a4 | ||
|
|
95c3579264 | ||
|
|
8802e5cad3 | ||
|
|
a41b4bbc82 | ||
|
|
c026c077b5 | ||
|
|
3eee539a86 | ||
|
|
143e7a0d72 | ||
|
|
4572358038 | ||
|
|
1753f94c11 | ||
|
|
120ddf2ef6 | ||
|
|
2cce5bc58f | ||
|
|
383a6001d2 | ||
|
|
3a6f45bfca | ||
|
|
e06b5ef202 | ||
|
|
c13ce816fa | ||
|
|
39f3e872ec | ||
|
|
b033c00217 | ||
|
|
6d47c5f21a | ||
|
|
0645540e24 | ||
|
|
a2c0fc4df0 | ||
|
|
7dccc88b35 | ||
|
|
ac617a51ce | ||
|
|
339a111a8f | ||
|
|
09b7e6fc9b | ||
|
|
135238014f | ||
|
|
303e37bf53 | ||
|
|
6a888e9900 | ||
|
|
e90a7767c6 | ||
|
|
1ded3af63c | ||
|
|
c53546c000 | ||
|
|
9afa12edda | ||
|
|
32046de962 |
@@ -114,8 +114,10 @@ jobs:
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
"""LLMProvider deprecated fields are nullable
|
||||
|
||||
Revision ID: 001984c88745
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-01 22:24:34.171100
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "001984c88745"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make default_model_name nullable (was NOT NULL)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
# is_default_provider and default_vision_model are already nullable with no server_default
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
op.execute(
|
||||
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=sa.false(),
|
||||
)
|
||||
@@ -123,9 +123,21 @@ def _seed_llms(
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
|
||||
if len(seeded_providers[0].model_configurations) > 0:
|
||||
default_model = next(
|
||||
(
|
||||
mc
|
||||
for mc in seeded_providers[0].model_configurations
|
||||
if mc.is_visible
|
||||
),
|
||||
seeded_providers[0].model_configurations[0],
|
||||
).name
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id,
|
||||
model_name=default_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
|
||||
@@ -302,12 +302,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -325,14 +325,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -361,14 +360,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -393,14 +391,13 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -432,12 +429,11 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -213,8 +213,12 @@ def upsert_llm_provider(
|
||||
llm_provider_upsert_request: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
) -> LLMProviderView:
|
||||
existing_llm_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
existing_llm_provider = (
|
||||
fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
if llm_provider_upsert_request.id
|
||||
else None
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
@@ -238,11 +242,6 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
|
||||
existing_llm_provider.default_model_name = (
|
||||
llm_provider_upsert_request.default_model_name
|
||||
)
|
||||
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
|
||||
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
|
||||
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
|
||||
@@ -251,6 +250,10 @@ def upsert_llm_provider(
|
||||
# If its not already in the db, we need to generate an ID by flushing
|
||||
db_session.flush()
|
||||
|
||||
models_to_exist = {
|
||||
mc.name for mc in llm_provider_upsert_request.model_configurations
|
||||
}
|
||||
|
||||
# Build a lookup of existing model configurations by name (single iteration)
|
||||
existing_by_name = {
|
||||
mc.name: mc for mc in existing_llm_provider.model_configurations
|
||||
@@ -306,15 +309,6 @@ def upsert_llm_provider(
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
|
||||
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
|
||||
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=existing_llm_provider.id,
|
||||
model=existing_llm_provider.default_model_name,
|
||||
flow_type=LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
@@ -488,6 +482,22 @@ def fetch_existing_llm_provider(
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_existing_llm_provider_by_id(
|
||||
id: int, db_session: Session
|
||||
) -> LLMProviderModel | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.id == id)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
return provider_model
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
@@ -604,22 +614,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
# Attempt to get the default_model_name from the provider first
|
||||
# TODO: Remove default_model_name check
|
||||
provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.id == provider_id,
|
||||
)
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
|
||||
|
||||
def update_default_provider(
|
||||
provider_id: int, model_name: str, db_session: Session
|
||||
) -> None:
|
||||
_update_default_model(
|
||||
db_session,
|
||||
provider_id,
|
||||
provider.default_model_name,
|
||||
model_name,
|
||||
LLMModelFlowType.CHAT,
|
||||
)
|
||||
|
||||
@@ -805,12 +806,6 @@ def sync_auto_mode_models(
|
||||
)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
db_session.commit()
|
||||
return changes
|
||||
|
||||
@@ -866,7 +861,6 @@ def insert_new_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
.returning(ModelConfiguration.id)
|
||||
@@ -901,7 +895,6 @@ def update_model_configuration__no_commit(
|
||||
is_visible=is_visible,
|
||||
max_input_tokens=max_input_tokens,
|
||||
display_name=display_name,
|
||||
supports_image_input=LLMModelFlowType.VISION in supported_flows,
|
||||
)
|
||||
.where(ModelConfiguration.id == model_configuration_id)
|
||||
.returning(ModelConfiguration)
|
||||
|
||||
@@ -2822,14 +2822,9 @@ class LLMProvider(Base):
|
||||
custom_config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
# Auto mode: models, visibility, and defaults are managed by GitHub config
|
||||
@@ -2879,8 +2874,6 @@ class ModelConfiguration(Base):
|
||||
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
|
||||
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
|
||||
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
|
||||
|
||||
@@ -97,7 +97,6 @@ def _build_llm_provider_request(
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name, # From request
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
default_model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
|
||||
api_key=provider_request.api_key,
|
||||
api_base=provider_request.api_base,
|
||||
api_version=provider_request.api_version,
|
||||
default_model_name=provider_request.default_model_name,
|
||||
deployment_name=provider_request.deployment_name,
|
||||
is_public=provider_request.is_public,
|
||||
custom_config=provider_request.custom_config,
|
||||
|
||||
@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_default_vision_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_provider_by_id
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_models
|
||||
from onyx.db.llm import fetch_persona_with_groups
|
||||
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
)
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
@@ -233,12 +237,12 @@ def test_llm_configuration(
|
||||
|
||||
test_api_key = test_llm_request.api_key
|
||||
test_custom_config = test_llm_request.custom_config
|
||||
if test_llm_request.name:
|
||||
if test_llm_request.id:
|
||||
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
|
||||
# as it turns out the name is not editable in the UI and other code also keys off name,
|
||||
# so we won't rock the boat just yet.
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=test_llm_request.name, db_session=db_session
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=test_llm_request.id, db_session=db_session
|
||||
)
|
||||
if existing_provider:
|
||||
test_custom_config = _restore_masked_custom_config_values(
|
||||
@@ -268,7 +272,7 @@ def test_llm_configuration(
|
||||
|
||||
llm = get_llm(
|
||||
provider=test_llm_request.provider,
|
||||
model=test_llm_request.default_model_name,
|
||||
model=test_llm_request.model,
|
||||
api_key=test_api_key,
|
||||
api_base=test_llm_request.api_base,
|
||||
api_version=test_llm_request.api_version,
|
||||
@@ -303,7 +307,7 @@ def list_llm_providers(
|
||||
include_image_gen: bool = Query(False),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderView]:
|
||||
) -> LLMProviderResponse[LLMProviderView]:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch LLM providers")
|
||||
|
||||
@@ -328,7 +332,25 @@ def list_llm_providers(
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
|
||||
|
||||
return llm_provider_list
|
||||
default_model = None
|
||||
if model_config := fetch_default_llm_model(db_session):
|
||||
default_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderView].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_model,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.put("/provider")
|
||||
@@ -341,21 +363,29 @@ def put_llm_provider(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LLMProviderView:
|
||||
# NOTE: Name updating functionality currently not supported. There are many places that still
|
||||
# rely on immutable names, so this will be a larger change
|
||||
|
||||
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
|
||||
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
|
||||
# the result
|
||||
existing_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
)
|
||||
existing_provider = None
|
||||
if llm_provider_upsert_request.id:
|
||||
existing_provider = fetch_existing_llm_provider_by_id(
|
||||
id=llm_provider_upsert_request.id, db_session=db_session
|
||||
)
|
||||
|
||||
if existing_provider and is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
|
||||
id={llm_provider_upsert_request.id} already exists",
|
||||
)
|
||||
elif not existing_provider and not is_creation:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
|
||||
detail=f"LLM Provider with name {llm_provider_upsert_request.name} and \
|
||||
id={llm_provider_upsert_request.id} does not exist",
|
||||
)
|
||||
|
||||
# SSRF Protection: Validate api_base and custom_config match stored values
|
||||
@@ -393,22 +423,6 @@ def put_llm_provider(
|
||||
deduplicated_personas.append(persona_id)
|
||||
llm_provider_upsert_request.personas = deduplicated_personas
|
||||
|
||||
default_model_found = False
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
if model_configuration.name == llm_provider_upsert_request.default_model_name:
|
||||
model_configuration.is_visible = True
|
||||
default_model_found = True
|
||||
|
||||
# TODO: Remove this logic on api change
|
||||
# Believed to be a dead pathway but we want to be safe for now
|
||||
if not default_model_found:
|
||||
llm_provider_upsert_request.model_configurations.append(
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=llm_provider_upsert_request.default_model_name, is_visible=True
|
||||
)
|
||||
)
|
||||
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
@@ -438,8 +452,8 @@ def put_llm_provider(
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if config and llm_provider_upsert_request.provider in config.providers:
|
||||
# Refetch the provider to get the updated model
|
||||
updated_provider = fetch_existing_llm_provider(
|
||||
name=llm_provider_upsert_request.name, db_session=db_session
|
||||
updated_provider = fetch_existing_llm_provider_by_id(
|
||||
id=result.id, db_session=db_session
|
||||
)
|
||||
if updated_provider:
|
||||
sync_auto_mode_models(
|
||||
@@ -469,28 +483,29 @@ def delete_llm_provider(
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default")
|
||||
@admin_router.post("/default")
|
||||
def set_provider_as_default(
|
||||
provider_id: int,
|
||||
default_model_request: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=default_model_request.provider_id,
|
||||
model_name=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
@admin_router.post("/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
provider_id: int,
|
||||
vision_model: str | None = Query(
|
||||
None, description="The default vision model to use"
|
||||
),
|
||||
default_model_request: DefaultModel,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if vision_model is None:
|
||||
raise HTTPException(status_code=404, detail="Vision model not provided")
|
||||
update_default_vision_provider(
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
provider_id=default_model_request.provider_id,
|
||||
vision_model=default_model_request.model_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -516,7 +531,7 @@ def get_auto_config(
|
||||
def get_vision_capable_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
) -> LLMProviderResponse[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
vision_models = fetch_existing_models(
|
||||
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
|
||||
@@ -545,7 +560,18 @@ def get_vision_capable_providers(
|
||||
]
|
||||
|
||||
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
|
||||
return vision_provider_response
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[VisionProviderResponse].from_models(
|
||||
providers=vision_provider_response,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
@@ -555,7 +581,7 @@ def get_vision_capable_providers(
|
||||
def list_llm_provider_basics(
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers accessible to the current user.
|
||||
|
||||
Returns:
|
||||
@@ -592,7 +618,25 @@ def list_llm_provider_basics(
|
||||
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return accessible_providers
|
||||
default_model = None
|
||||
if model_config := fetch_default_llm_model(db_session):
|
||||
default_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
default_vision_model = None
|
||||
if model_config := fetch_default_vision_model(db_session):
|
||||
default_vision_model = DefaultModel(
|
||||
provider_id=model_config.llm_provider.id,
|
||||
model_name=model_config.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=accessible_providers,
|
||||
default_text=default_model,
|
||||
default_vision=default_vision_model,
|
||||
)
|
||||
|
||||
|
||||
def get_valid_model_names_for_persona(
|
||||
@@ -635,7 +679,7 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
@@ -682,7 +726,63 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return llm_provider_list
|
||||
# Get the default model and vision model for the persona
|
||||
# NOTE: This should be ported over to use id as it is blocking on name mutability
|
||||
persona_default_provider = persona.llm_model_provider_override
|
||||
persona_default_model = persona.llm_model_version_override
|
||||
|
||||
default_text_model = fetch_default_llm_model(db_session)
|
||||
default_vision_model = fetch_default_vision_model(db_session)
|
||||
|
||||
# Build default_text and default_vision using persona overrides when available,
|
||||
# falling back to the global defaults.
|
||||
default_text: DefaultModel | None = (
|
||||
DefaultModel(
|
||||
provider_id=default_text_model.llm_provider.id,
|
||||
model_name=default_text_model.name,
|
||||
)
|
||||
if default_text_model
|
||||
else None
|
||||
)
|
||||
default_vision: DefaultModel | None = (
|
||||
DefaultModel(
|
||||
provider_id=default_vision_model.llm_provider.id,
|
||||
model_name=default_vision_model.name,
|
||||
)
|
||||
if default_vision_model
|
||||
else None
|
||||
)
|
||||
|
||||
if persona_default_provider:
|
||||
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
|
||||
if provider:
|
||||
if persona_default_model:
|
||||
# Persona specifies both provider and model — use them directly
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=persona_default_model,
|
||||
)
|
||||
else:
|
||||
# Persona specifies only the provider — pick a visible (public) model,
|
||||
# falling back to any model on this provider
|
||||
visible_model = next(
|
||||
(mc for mc in provider.model_configurations if mc.is_visible),
|
||||
None,
|
||||
)
|
||||
fallback_model = visible_model or next(
|
||||
iter(provider.model_configurations), None
|
||||
)
|
||||
if fallback_model:
|
||||
default_text = DefaultModel(
|
||||
provider_id=provider.id,
|
||||
model_name=fallback_model.name,
|
||||
)
|
||||
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -21,6 +23,8 @@ if TYPE_CHECKING:
|
||||
ModelConfiguration as ModelConfigurationModel,
|
||||
)
|
||||
|
||||
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
|
||||
|
||||
|
||||
# TODO: Clear this up on api refactor
|
||||
# There is still logic that requires sending each providers default model name
|
||||
@@ -52,19 +56,18 @@ def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str
|
||||
|
||||
class TestLLMRequest(BaseModel):
|
||||
# provider level
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
provider: str
|
||||
model: str
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
|
||||
# model level
|
||||
default_model_name: str
|
||||
deployment_name: str | None = None
|
||||
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"]
|
||||
|
||||
# if try and use the existing API/custom config key
|
||||
api_key_changed: bool
|
||||
custom_config_changed: bool
|
||||
@@ -80,13 +83,10 @@ class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
non-admin users. Used when giving a list of available LLMs."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
|
||||
default_model_name: str
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -99,22 +99,12 @@ class LLMProviderDescriptor(BaseModel):
|
||||
)
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
provider_display_name=get_provider_display_name(provider),
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
@@ -128,18 +118,17 @@ class LLMProvider(BaseModel):
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
is_public: bool = True
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
personas: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
id: int | None = None
|
||||
api_key_changed: bool = False
|
||||
custom_config_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
@@ -155,8 +144,6 @@ class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_configurations: list["ModelConfigurationView"]
|
||||
|
||||
@classmethod
|
||||
@@ -178,14 +165,6 @@ class LLMProviderView(LLMProvider):
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
default_model_name = get_default_llm_model_name(llm_provider_model)
|
||||
default_vision_model = get_default_vision_model_name(llm_provider_model)
|
||||
|
||||
is_default_provider = bool(default_model_name)
|
||||
is_default_vision_provider = default_vision_model is not None
|
||||
|
||||
default_model_name = default_model_name or llm_provider_model.default_model_name
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
@@ -198,10 +177,6 @@ class LLMProviderView(LLMProvider):
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
default_model_name=default_model_name,
|
||||
is_default_provider=is_default_provider,
|
||||
is_default_vision_provider=is_default_vision_provider,
|
||||
default_vision_model=default_vision_model,
|
||||
is_public=llm_provider_model.is_public,
|
||||
is_auto_mode=llm_provider_model.is_auto_mode,
|
||||
groups=groups,
|
||||
@@ -228,7 +203,8 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
name=model_configuration_model.name,
|
||||
is_visible=model_configuration_model.is_visible,
|
||||
max_input_tokens=model_configuration_model.max_input_tokens,
|
||||
supports_image_input=model_configuration_model.supports_image_input,
|
||||
supports_image_input=LLMModelFlowType.VISION
|
||||
in model_configuration_model.llm_model_flow_types,
|
||||
display_name=model_configuration_model.display_name,
|
||||
)
|
||||
|
||||
@@ -421,3 +397,27 @@ class OpenRouterFinalModelResponse(BaseModel):
|
||||
int | None
|
||||
) # From OpenRouter API context_length (may be missing for some models)
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
class DefaultModel(BaseModel):
|
||||
provider_id: int
|
||||
model_name: str
|
||||
|
||||
|
||||
class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
providers: list[T]
|
||||
default_text: DefaultModel | None = None
|
||||
default_vision: DefaultModel | None = None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
providers: list[T],
|
||||
default_text: DefaultModel | None = None,
|
||||
default_vision: DefaultModel | None = None,
|
||||
) -> "LLMProviderResponse[T]":
|
||||
return cls(
|
||||
providers=providers,
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
@@ -245,7 +245,11 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
if GEN_AI_API_KEY and fetch_default_llm_model(db_session) is None:
|
||||
if (
|
||||
GEN_AI_API_KEY
|
||||
and fetch_default_llm_model(db_session) is None
|
||||
and not INTEGRATION_TESTS_MODE
|
||||
):
|
||||
# Only for dev flows
|
||||
logger.notice("Setting up default OpenAI LLM for dev.")
|
||||
|
||||
@@ -257,7 +261,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=llm_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -269,7 +272,9 @@ def setup_postgres(db_session: Session) -> None:
|
||||
new_llm_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=model_req, db_session=db_session
|
||||
)
|
||||
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
|
||||
update_default_provider(
|
||||
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
|
||||
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -26,7 +26,6 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
|
||||
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
|
||||
"api_key_changed": True,
|
||||
"custom_config_changed": True,
|
||||
}
|
||||
@@ -44,7 +43,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": LlmProviderNames.BEDROCK,
|
||||
"default_model_name": _DEFAULT_BEDROCK_MODEL,
|
||||
"model": _DEFAULT_BEDROCK_MODEL,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
@@ -53,7 +52,6 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
|
||||
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
|
||||
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
|
||||
},
|
||||
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
|
||||
"api_key_changed": True,
|
||||
"custom_config_changed": True,
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=os.environ.get("OPENAI_API_KEY", "test"),
|
||||
is_public=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini",
|
||||
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
llm_provider_upsert_request=llm_provider_request,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
|
||||
@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.ANTHROPIC,
|
||||
api_key=anthropic_api_key,
|
||||
default_model_name=anthropic_model,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
model_configurations=[
|
||||
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
|
||||
)
|
||||
|
||||
try:
|
||||
update_default_provider(anthropic_provider.id, db_session)
|
||||
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
|
||||
|
||||
test_user = create_test_user(db_session, email_prefix="anthropic_only")
|
||||
chat_session = create_chat_session(
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
|
||||
test_llm_configuration as run_test_llm_configuration,
|
||||
)
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
|
||||
|
||||
@@ -44,15 +45,14 @@ def _create_test_provider(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
api_key: str = "sk-test-key-00000000000000000000000000000000000",
|
||||
) -> None:
|
||||
) -> LLMProviderView:
|
||||
"""Helper to create a test LLM provider in the database."""
|
||||
upsert_llm_provider(
|
||||
return upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -107,12 +107,7 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-new-test-key-0000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -157,12 +152,7 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-invalid-key-00000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -194,7 +184,9 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database
|
||||
_create_test_provider(db_session, provider_name, api_key=original_api_key)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_key=original_api_key
|
||||
)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -202,17 +194,13 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with api_key_changed=False - should use stored key
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name, # Existing provider
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not providing a new key
|
||||
api_key_changed=False, # Using existing key
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -259,12 +247,7 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key=new_api_key, # Providing a new key
|
||||
api_key_changed=True, # Key is being changed
|
||||
custom_config_changed=False,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -297,7 +280,7 @@ class TestLLMConfigurationEndpoint:
|
||||
|
||||
try:
|
||||
# First, create the provider in the database with custom_config
|
||||
upsert_llm_provider(
|
||||
provider = upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
@@ -305,12 +288,6 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key_changed=True,
|
||||
custom_config=original_custom_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -321,18 +298,14 @@ class TestLLMConfigurationEndpoint:
|
||||
# Test with custom_config_changed=False - should use stored config
|
||||
run_test_llm_configuration(
|
||||
test_llm_request=LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
custom_config=None, # Not providing new config
|
||||
custom_config_changed=False, # Using existing config
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
)
|
||||
],
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -373,12 +346,7 @@ class TestLLMConfigurationEndpoint:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=False,
|
||||
default_model_name=model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=model_name,
|
||||
),
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
@@ -442,7 +410,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_initial_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -452,7 +419,7 @@ class TestDefaultProviderEndpoint:
|
||||
)
|
||||
|
||||
# Set provider 1 as the default provider explicitly
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
|
||||
|
||||
# Step 2: Call run_test_default_provider - should use provider 1's default model
|
||||
with patch(
|
||||
@@ -472,7 +439,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -499,11 +465,11 @@ class TestDefaultProviderEndpoint:
|
||||
# Step 5: Update provider 1's default model
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
id=provider_1.id,
|
||||
name=provider_1_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
default_model_name=provider_1_updated_model, # Changed
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
|
||||
@@ -512,6 +478,9 @@ class TestDefaultProviderEndpoint:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set provider 1's default model to the updated model
|
||||
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
|
||||
|
||||
# Step 6: Call run_test_default_provider - should use new model on provider 1
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
|
||||
@@ -524,7 +493,7 @@ class TestDefaultProviderEndpoint:
|
||||
captured_llms.clear()
|
||||
|
||||
# Step 7: Change the default provider to provider 2
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
|
||||
# Step 8: Call run_test_default_provider - should use provider 2
|
||||
with patch(
|
||||
@@ -596,7 +565,6 @@ class TestDefaultProviderEndpoint:
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o-mini", is_visible=True
|
||||
@@ -605,7 +573,7 @@ class TestDefaultProviderEndpoint:
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, "gpt-4o-mini", db_session)
|
||||
|
||||
# Test should fail
|
||||
with patch(
|
||||
|
||||
@@ -20,6 +20,7 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import UserRole
|
||||
@@ -49,7 +50,6 @@ def _create_test_provider(
|
||||
api_key_changed=True,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
|
||||
],
|
||||
@@ -91,14 +91,14 @@ class TestLLMProviderChanges:
|
||||
the API key should be blocked.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://attacker.example.com",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -125,16 +125,16 @@ class TestLLMProviderChanges:
|
||||
Changing api_base IS allowed when the API key is also being changed.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom-endpoint.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -159,14 +159,16 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=original_api_base,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -190,14 +192,14 @@ class TestLLMProviderChanges:
|
||||
changes. This allows model-only updates when provider has no custom base URL.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=None)
|
||||
view = _create_test_provider(db_session, provider_name, api_base=None)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -223,14 +225,16 @@ class TestLLMProviderChanges:
|
||||
original_api_base = "https://original.example.com/v1"
|
||||
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name, api_base=original_api_base)
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, api_base=original_api_base
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base=None,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -259,14 +263,14 @@ class TestLLMProviderChanges:
|
||||
users have full control over their deployment.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -297,7 +301,6 @@ class TestLLMProviderChanges:
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
api_base="https://custom.example.com/v1",
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -322,7 +325,7 @@ class TestLLMProviderChanges:
|
||||
redirect LLM API requests).
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"SOME_CONFIG": "original_value"},
|
||||
@@ -330,11 +333,11 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -362,15 +365,15 @@ class TestLLMProviderChanges:
|
||||
without changing the API key.
|
||||
"""
|
||||
try:
|
||||
_create_test_provider(db_session, provider_name)
|
||||
provider = _create_test_provider(db_session, provider_name)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
@@ -399,7 +402,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "us-west-2"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -407,13 +410,13 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-new-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
custom_config_changed=True,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -438,17 +441,17 @@ class TestLLMProviderChanges:
|
||||
original_config = {"AWS_REGION_NAME": "us-east-1"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session, provider_name, custom_config=original_config
|
||||
)
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=original_config,
|
||||
custom_config_changed=True,
|
||||
default_model_name="gpt-4o-mini",
|
||||
)
|
||||
|
||||
result = put_llm_provider(
|
||||
@@ -474,7 +477,7 @@ class TestLLMProviderChanges:
|
||||
new_config = {"AWS_REGION_NAME": "eu-west-1"}
|
||||
|
||||
try:
|
||||
_create_test_provider(
|
||||
provider = _create_test_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
custom_config={"AWS_REGION_NAME": "us-east-1"},
|
||||
@@ -482,10 +485,10 @@ class TestLLMProviderChanges:
|
||||
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
update_request = LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
custom_config=new_config,
|
||||
default_model_name="gpt-4o-mini",
|
||||
custom_config_changed=True,
|
||||
)
|
||||
|
||||
@@ -532,12 +535,7 @@ def test_upload_with_custom_config_then_change(
|
||||
LLMTestRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config=custom_config,
|
||||
@@ -546,11 +544,10 @@ def test_upload_with_custom_config_then_change(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -569,14 +566,10 @@ def test_upload_with_custom_config_then_change(
|
||||
# Turn auto mode off
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=False,
|
||||
),
|
||||
@@ -586,9 +579,9 @@ def test_upload_with_custom_config_then_change(
|
||||
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
@@ -616,7 +609,9 @@ def test_upload_with_custom_config_then_change(
|
||||
)
|
||||
|
||||
# Check inside the database and check that custom_config is the same as the original
|
||||
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
|
||||
provider = fetch_llm_provider_view(
|
||||
db_session=db_session, provider_name=name
|
||||
)
|
||||
if not provider:
|
||||
assert False, "Provider not found in the database"
|
||||
|
||||
@@ -642,11 +637,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
}
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
view = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -665,9 +659,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
|
||||
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config={
|
||||
"vertex_credentials": _mask_string(
|
||||
original_custom_config["vertex_credentials"]
|
||||
@@ -719,11 +713,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
return ""
|
||||
|
||||
try:
|
||||
put_llm_provider(
|
||||
view = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
custom_config=original_custom_config,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -742,14 +735,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
|
||||
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
|
||||
run_llm_config_test(
|
||||
LLMTestRequest(
|
||||
id=view.id,
|
||||
name=name,
|
||||
provider=provider,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name, is_visible=True
|
||||
)
|
||||
],
|
||||
model=default_model_name,
|
||||
api_key_changed=False,
|
||||
custom_config_changed=True,
|
||||
custom_config={
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_llm_provider_view
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -135,7 +136,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=expected_default_model,
|
||||
model_configurations=[], # No model configs provided
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -163,13 +163,8 @@ class TestAutoModeSyncFeature:
|
||||
if mc.name in all_expected_models:
|
||||
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
|
||||
|
||||
# Verify the default model was set correctly
|
||||
assert (
|
||||
provider.default_model_name == expected_default_model
|
||||
), f"Default model should be '{expected_default_model}'"
|
||||
|
||||
# Step 4: Set the provider as default
|
||||
update_default_provider(provider.id, db_session)
|
||||
update_default_provider(provider.id, expected_default_model, db_session)
|
||||
|
||||
# Step 5: Fetch the default provider and verify
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
@@ -238,7 +233,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -310,14 +304,13 @@ class TestAutoModeSyncFeature:
|
||||
|
||||
try:
|
||||
# Step 1: Upload provider WITHOUT auto mode, with initial models
|
||||
put_llm_provider(
|
||||
provider = put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=False, # Not in auto mode initially
|
||||
default_model_name="gpt-4",
|
||||
model_configurations=initial_models,
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -344,12 +337,12 @@ class TestAutoModeSyncFeature:
|
||||
):
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
id=provider.id,
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None, # Not changing API key
|
||||
api_key_changed=False,
|
||||
is_auto_mode=True, # Now enabling auto mode
|
||||
default_model_name=auto_mode_default,
|
||||
model_configurations=[], # Auto mode will sync from config
|
||||
),
|
||||
is_creation=False, # This is an update
|
||||
@@ -360,8 +353,8 @@ class TestAutoModeSyncFeature:
|
||||
# Step 3: Verify model visibility after auto mode transition
|
||||
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
provider = fetch_llm_provider_view(
|
||||
db_session=db_session, provider_name=provider_name
|
||||
)
|
||||
assert provider is not None
|
||||
assert provider.is_auto_mode is True
|
||||
@@ -388,9 +381,6 @@ class TestAutoModeSyncFeature:
|
||||
model_visibility[model_name] is False
|
||||
), f"Model '{model_name}' not in auto config should NOT be visible"
|
||||
|
||||
# Verify the default model was updated
|
||||
assert provider.default_model_name == auto_mode_default
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
@@ -432,8 +422,12 @@ class TestAutoModeSyncFeature:
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name="gpt-4o",
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
@@ -535,7 +529,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_1_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_1_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -549,7 +542,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_1_name, db_session=db_session
|
||||
)
|
||||
assert provider_1 is not None
|
||||
update_default_provider(provider_1.id, db_session)
|
||||
update_default_provider(provider_1.id, provider_1_default_model, db_session)
|
||||
|
||||
with patch(
|
||||
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
|
||||
@@ -563,7 +556,6 @@ class TestAutoModeSyncFeature:
|
||||
api_key=provider_2_api_key,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name=provider_2_default_model,
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
@@ -584,7 +576,7 @@ class TestAutoModeSyncFeature:
|
||||
name=provider_2_name, db_session=db_session
|
||||
)
|
||||
assert provider_2 is not None
|
||||
update_default_provider(provider_2.id, db_session)
|
||||
update_default_provider(provider_2.id, provider_2_default_model, db_session)
|
||||
|
||||
# Step 5: Verify provider 2 is now the default
|
||||
db_session.expire_all()
|
||||
|
||||
@@ -64,7 +64,6 @@ def _create_provider(
|
||||
name=name,
|
||||
provider=provider,
|
||||
api_key="sk-ant-api03-...",
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
|
||||
)
|
||||
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
|
||||
|
||||
update_default_provider(public_provider_id, db_session)
|
||||
update_default_provider(
|
||||
public_provider_id, "claude-3-5-sonnet-20240620", db_session
|
||||
)
|
||||
|
||||
try:
|
||||
# Create chat session
|
||||
|
||||
@@ -434,7 +434,6 @@ class TestSlackBotFederatedSearch:
|
||||
name=f"test-llm-provider-{uuid4().hex[:8]}",
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=api_key,
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
@@ -448,7 +447,7 @@ class TestSlackBotFederatedSearch:
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_default_provider(provider_view.id, db_session)
|
||||
update_default_provider(provider_view.id, "gpt-4o", db_session)
|
||||
|
||||
def _teardown_common_mocks(self, patches: list) -> None:
|
||||
"""Stop all patches"""
|
||||
|
||||
@@ -4,10 +4,12 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -32,7 +34,6 @@ class LLMProviderManager:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or LlmProviderNames.OPENAI,
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
@@ -65,7 +66,6 @@ class LLMProviderManager:
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
is_auto_mode=response_data.get("is_auto_mode", False),
|
||||
groups=response_data["groups"],
|
||||
@@ -75,9 +75,20 @@ class LLMProviderManager:
|
||||
)
|
||||
|
||||
if set_as_default:
|
||||
if default_model_name is None:
|
||||
default_model_name = "gpt-4o-mini"
|
||||
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers,
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
json={
|
||||
"provider_id": response_data["id"],
|
||||
"model_name": default_model_name,
|
||||
},
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
@@ -113,7 +124,12 @@ class LLMProviderManager:
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
|
||||
default_model = LLMProviderManager.get_default_model(user_performing_action)
|
||||
for fetched_llm_provider in all_llm_providers:
|
||||
model_names = [
|
||||
model.name for model in fetched_llm_provider.model_configurations
|
||||
]
|
||||
|
||||
if llm_provider.id == fetched_llm_provider.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
@@ -126,11 +142,25 @@ class LLMProviderManager:
|
||||
if (
|
||||
fetched_llm_groups == llm_provider_groups
|
||||
and llm_provider.provider == fetched_llm_provider.provider
|
||||
and llm_provider.default_model_name
|
||||
== fetched_llm_provider.default_model_name
|
||||
and default_model.model_name in model_names
|
||||
and llm_provider.is_public == fetched_llm_provider.is_public
|
||||
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"LLM Provider {llm_provider.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def get_default_model(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> DefaultModel:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return DefaultModel(**response.json())
|
||||
|
||||
@@ -116,7 +116,6 @@ class DATestLLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
is_auto_mode: bool = False
|
||||
groups: list[int]
|
||||
|
||||
@@ -72,7 +72,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
for provider in response.json():
|
||||
for provider in response.json()["providers"]:
|
||||
if provider["id"] == provider_id:
|
||||
return provider
|
||||
raise ValueError(f"Provider with id {provider_id} not found")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,6 +9,8 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
@@ -20,6 +22,8 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
@@ -41,24 +45,32 @@ def _create_llm_provider(
|
||||
is_public: bool,
|
||||
is_default: bool,
|
||||
) -> LLMProviderModel:
|
||||
provider = LLMProviderModel(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
default_model_name=default_model_name,
|
||||
deployment_name=None,
|
||||
is_public=is_public,
|
||||
# Use None instead of False to avoid unique constraint violation
|
||||
# The is_default_provider column has unique=True, so only one True and one False allowed
|
||||
is_default_provider=is_default if is_default else None,
|
||||
is_default_vision_provider=False,
|
||||
default_vision_model=None,
|
||||
_provider = upsert_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key=None,
|
||||
api_base=None,
|
||||
api_version=None,
|
||||
custom_config=None,
|
||||
is_public=is_public,
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
],
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.add(provider)
|
||||
db_session.flush()
|
||||
|
||||
if is_default:
|
||||
update_default_provider(_provider.id, default_model_name, db_session)
|
||||
|
||||
provider = db_session.get(LLMProviderModel, _provider.id)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider {name} not found")
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@@ -321,13 +333,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
persona=persona,
|
||||
user=admin_model,
|
||||
)
|
||||
assert allowed_llm.config.model_name == restricted_provider.default_model_name
|
||||
assert (
|
||||
allowed_llm.config.model_name
|
||||
== restricted_provider.model_configurations[0].name
|
||||
)
|
||||
|
||||
fallback_llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=basic_model,
|
||||
)
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
assert (
|
||||
fallback_llm.config.model_name
|
||||
== default_provider.model_configurations[0].name
|
||||
)
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
@@ -346,6 +364,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -365,7 +384,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
@@ -380,7 +399,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_providers = admin_response.json()["providers"]
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
@@ -396,6 +415,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
|
||||
name="default-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
default_model_name="gpt-4o",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should include the restricted provider since basic_user can access the persona
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
|
||||
|
||||
# Should succeed (persona_id=0 refers to default persona, which is public)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should NOT include the restricted provider since basic_user is not in group2
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
|
||||
|
||||
# Should succeed - admins can access any persona
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should include the restricted provider
|
||||
provider_names = [p["name"] for p in providers]
|
||||
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
|
||||
|
||||
# Should succeed
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
providers = response.json()["providers"]
|
||||
|
||||
# Should return the public provider
|
||||
assert len(providers) > 0
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
|
||||
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
|
||||
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
|
||||
import { getImageGenForm } from "./forms";
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import { Select } from "@/refresh-components/cards";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
IMAGE_PROVIDER_GROUPS,
|
||||
ImageProvider,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { FormikProps } from "formik";
|
||||
import { ImageProvider } from "../constants";
|
||||
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import {
|
||||
ImageGenerationConfigView,
|
||||
ImageGenerationCredentials,
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useSWR from "swr";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Title from "@/components/ui/title";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { LLMProviderView } from "./interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||
import { OpenAIForm } from "./forms/OpenAIForm";
|
||||
import { AnthropicForm } from "./forms/AnthropicForm";
|
||||
import { OllamaForm } from "./forms/OllamaForm";
|
||||
import { AzureForm } from "./forms/AzureForm";
|
||||
import { BedrockForm } from "./forms/BedrockForm";
|
||||
import { VertexAIForm } from "./forms/VertexAIForm";
|
||||
import { OpenRouterForm } from "./forms/OpenRouterForm";
|
||||
import { getFormForExistingProvider } from "./forms/getForm";
|
||||
import { CustomForm } from "./forms/CustomForm";
|
||||
|
||||
export function LLMConfiguration() {
|
||||
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
if (!existingLlmProviders) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
const isFirstProvider = existingLlmProviders.length === 0;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Title className="mb-2">Enabled LLM Providers</Title>
|
||||
|
||||
{existingLlmProviders.length > 0 ? (
|
||||
<>
|
||||
<Text as="p" className="mb-4">
|
||||
If multiple LLM providers are enabled, the default provider will be
|
||||
used for all "Default" Assistants. For user-created
|
||||
Assistants, you can select the LLM provider/model that best fits the
|
||||
use case!
|
||||
</Text>
|
||||
<div className="flex flex-col gap-y-4">
|
||||
{[...existingLlmProviders]
|
||||
.sort((a, b) => {
|
||||
if (a.is_default_provider && !b.is_default_provider) return -1;
|
||||
if (!a.is_default_provider && b.is_default_provider) return 1;
|
||||
return 0;
|
||||
})
|
||||
.map((llmProvider) => (
|
||||
<div key={llmProvider.id}>
|
||||
{getFormForExistingProvider(llmProvider)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<Callout type="warning" title="No LLM providers configured yet">
|
||||
Please set one up below in order to start using Onyx!
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<Title className="mb-2 mt-6">Add LLM Provider</Title>
|
||||
<Text as="p" className="mb-4">
|
||||
Add a new LLM provider by either selecting from one of the default
|
||||
providers or by specifying your own custom LLM provider.
|
||||
</Text>
|
||||
|
||||
<div className="flex flex-col gap-y-4">
|
||||
<OpenAIForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<AnthropicForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<OllamaForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<AzureForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<BedrockForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<VertexAIForm shouldMarkAsDefault={isFirstProvider} />
|
||||
<OpenRouterForm shouldMarkAsDefault={isFirstProvider} />
|
||||
|
||||
<CustomForm shouldMarkAsDefault={isFirstProvider} />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,13 +1,14 @@
|
||||
"use client";
|
||||
|
||||
import { ArrayHelpers, FieldArray, FormikProps, useField } from "formik";
|
||||
import { ModelConfiguration } from "./interfaces";
|
||||
import { ModelConfiguration } from "@/interfaces/llm";
|
||||
import { ManualErrorMessage, TextFormField } from "@/components/Field";
|
||||
import { useEffect, useState } from "react";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
function ModelConfigurationRow({
|
||||
name,
|
||||
index,
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
import { LLMProviderName, LLMProviderView } from "../interfaces";
|
||||
import { AnthropicForm } from "./AnthropicForm";
|
||||
import { OpenAIForm } from "./OpenAIForm";
|
||||
import { OllamaForm } from "./OllamaForm";
|
||||
import { AzureForm } from "./AzureForm";
|
||||
import { VertexAIForm } from "./VertexAIForm";
|
||||
import { OpenRouterForm } from "./OpenRouterForm";
|
||||
import { CustomForm } from "./CustomForm";
|
||||
import { BedrockForm } from "./BedrockForm";
|
||||
|
||||
export function detectIfRealOpenAIProvider(provider: LLMProviderView) {
|
||||
return (
|
||||
provider.provider === LLMProviderName.OPENAI &&
|
||||
provider.api_key &&
|
||||
!provider.api_base &&
|
||||
Object.keys(provider.custom_config || {}).length === 0
|
||||
);
|
||||
}
|
||||
|
||||
export const getFormForExistingProvider = (provider: LLMProviderView) => {
|
||||
switch (provider.provider) {
|
||||
case LLMProviderName.OPENAI:
|
||||
// "openai" as a provider name can be used for litellm proxy / any OpenAI-compatible provider
|
||||
if (detectIfRealOpenAIProvider(provider)) {
|
||||
return <OpenAIForm existingLlmProvider={provider} />;
|
||||
} else {
|
||||
return <CustomForm existingLlmProvider={provider} />;
|
||||
}
|
||||
case LLMProviderName.ANTHROPIC:
|
||||
return <AnthropicForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.OLLAMA_CHAT:
|
||||
return <OllamaForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.AZURE:
|
||||
return <AzureForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.VERTEX_AI:
|
||||
return <VertexAIForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.BEDROCK:
|
||||
return <BedrockForm existingLlmProvider={provider} />;
|
||||
case LLMProviderName.OPENROUTER:
|
||||
return <OpenRouterForm existingLlmProvider={provider} />;
|
||||
default:
|
||||
return <CustomForm existingLlmProvider={provider} />;
|
||||
}
|
||||
};
|
||||
@@ -1,14 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { LLMConfiguration } from "./LLMConfiguration";
|
||||
import { SvgCpu } from "@opal/icons";
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="LLM Setup" icon={SvgCpu} />
|
||||
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
|
||||
|
||||
<LLMConfiguration />
|
||||
</>
|
||||
);
|
||||
export default function Page() {
|
||||
return <LLMConfigurationPage />;
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ import {
|
||||
BedrockFetchParams,
|
||||
OllamaFetchParams,
|
||||
OpenRouterFetchParams,
|
||||
} from "./interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import { SvgAws, SvgOpenrouter } from "@opal/icons";
|
||||
|
||||
// Aggregator providers that host models from multiple vendors
|
||||
@@ -106,8 +106,8 @@ export const getProviderIcon = (
|
||||
return CPUIcon;
|
||||
};
|
||||
|
||||
export const isAnthropic = (provider: string, modelName: string) =>
|
||||
provider === "anthropic" || modelName.toLowerCase().includes("claude");
|
||||
export const isAnthropic = (provider: string, modelName?: string) =>
|
||||
provider === "anthropic" || !!modelName?.toLowerCase().includes("claude");
|
||||
|
||||
/**
|
||||
* Fetches Bedrock models directly without any form state dependencies.
|
||||
@@ -153,6 +153,7 @@ export const fetchBedrockModels = async (
|
||||
is_visible: false,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: false,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
@@ -205,6 +206,7 @@ export const fetchOllamaModels = async (
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: false,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
@@ -262,6 +264,7 @@ export const fetchOpenRouterModels = async (
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: false,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
|
||||
@@ -25,7 +25,7 @@ import { ModelOption } from "@/components/embedding/ModelSelector";
|
||||
import {
|
||||
EMBEDDING_MODELS_ADMIN_URL,
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
} from "@/app/admin/configuration/llm/constants";
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { AdvancedSearchConfiguration } from "@/app/admin/embeddings/interfaces";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import {
|
||||
import {
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/app/admin/configuration/llm/constants";
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { mutate } from "swr";
|
||||
import { testEmbedding } from "@/app/admin/embeddings/pages/utils";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
EmbeddingProvider,
|
||||
getFormattedProviderName,
|
||||
} from "@/components/embedding/interfaces";
|
||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
|
||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
export interface ProviderCreationModalProps {
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
|
||||
import { StringOrNumberOption } from "@/components/Dropdown";
|
||||
import useSWR from "swr";
|
||||
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants";
|
||||
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
|
||||
@@ -18,7 +18,7 @@ import SourceTag from "@/refresh-components/buttons/source-tag/SourceTag";
|
||||
import { citationsToSourceInfoArray } from "@/refresh-components/buttons/source-tag/sourceTagUtils";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import LLMPopover from "@/refresh-components/popovers/LLMPopover";
|
||||
import { parseLlmDescriptor } from "@/lib/llm/utils";
|
||||
import { parseLlmDescriptor } from "@/lib/llmConfig/utils";
|
||||
import { LlmManager } from "@/lib/hooks";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import { SvgThumbsDown, SvgThumbsUp } from "@opal/icons";
|
||||
|
||||
@@ -11,7 +11,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import Switch from "@/refresh-components/inputs/Switch";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
BuildLlmSelection,
|
||||
BUILD_MODE_PROVIDERS,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useMemo, useState, useCallback } from "react";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
BuildLlmSelection,
|
||||
getBuildLlmSelection,
|
||||
|
||||
@@ -7,7 +7,7 @@ import { usePreProvisionPolling } from "@/app/craft/hooks/usePreProvisionPolling
|
||||
import { CRAFT_SEARCH_PARAM_NAMES } from "@/app/craft/services/searchParams";
|
||||
import { CRAFT_PATH } from "@/app/craft/v1/constants";
|
||||
import { getBuildUserPersona } from "@/app/craft/onboarding/constants";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { checkPreProvisionedSession } from "@/app/craft/services/apiServices";
|
||||
|
||||
interface UseBuildSessionControllerProps {
|
||||
|
||||
@@ -18,8 +18,8 @@ import {
|
||||
getBuildLlmSelection,
|
||||
getDefaultLlmSelection,
|
||||
} from "@/app/craft/onboarding/constants";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import {
|
||||
buildInitialValues,
|
||||
testApiKeyHelper,
|
||||
|
||||
@@ -5,10 +5,7 @@ import { cn } from "@/lib/utils";
|
||||
import { Disabled } from "@/refresh-components/Disabled";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderName, LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
|
||||
// Provider configurations
|
||||
export type ProviderKey = "anthropic" | "openai" | "openrouter";
|
||||
|
||||
@@ -19,13 +19,12 @@ const LLM_SELECTION_PRIORITY = [
|
||||
interface MinimalLlmProvider {
|
||||
name: string;
|
||||
provider: string;
|
||||
default_model_name: string;
|
||||
is_default_provider: boolean | null;
|
||||
model_configurations: { name: string; is_visible: boolean }[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the best default LLM selection based on available providers.
|
||||
* Priority: Anthropic > OpenAI > OpenRouter > system default > first available
|
||||
* Priority: Anthropic > OpenAI > OpenRouter > first available
|
||||
*/
|
||||
export function getDefaultLlmSelection(
|
||||
llmProviders: MinimalLlmProvider[] | undefined
|
||||
@@ -44,23 +43,16 @@ export function getDefaultLlmSelection(
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: use the default provider's default model
|
||||
const defaultProvider = llmProviders.find((p) => p.is_default_provider);
|
||||
if (defaultProvider) {
|
||||
return {
|
||||
providerName: defaultProvider.name,
|
||||
provider: defaultProvider.provider,
|
||||
modelName: defaultProvider.default_model_name,
|
||||
};
|
||||
}
|
||||
|
||||
// Final fallback: first available provider
|
||||
// Fallback: first available provider, use its first visible model
|
||||
const firstProvider = llmProviders[0];
|
||||
if (firstProvider) {
|
||||
const firstModel = firstProvider.model_configurations.find(
|
||||
(m) => m.is_visible
|
||||
);
|
||||
return {
|
||||
providerName: firstProvider.name,
|
||||
provider: firstProvider.provider,
|
||||
modelName: firstProvider.default_model_name,
|
||||
modelName: firstModel?.name ?? "",
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import { useCallback, useState, useMemo, useEffect } from "react";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
import { LLMProviderName } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { LLMProviderName } from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingModalMode,
|
||||
OnboardingModalController,
|
||||
@@ -18,9 +18,7 @@ import { useBuildSessionStore } from "@/app/craft/hooks/useBuildSessionStore";
|
||||
|
||||
// Check if all 3 build mode providers are configured (anthropic, openai, openrouter)
|
||||
function checkAllProvidersConfigured(
|
||||
llmProviders:
|
||||
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
|
||||
| undefined
|
||||
llmProviders: import("@/interfaces/llm").LLMProviderDescriptor[] | undefined
|
||||
): boolean {
|
||||
if (!llmProviders || llmProviders.length === 0) {
|
||||
return false;
|
||||
@@ -35,9 +33,7 @@ function checkAllProvidersConfigured(
|
||||
|
||||
// Check if at least one provider is configured
|
||||
function checkHasAnyProvider(
|
||||
llmProviders:
|
||||
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
|
||||
| undefined
|
||||
llmProviders: import("@/interfaces/llm").LLMProviderDescriptor[] | undefined
|
||||
): boolean {
|
||||
return !!(llmProviders && llmProviders.length > 0);
|
||||
}
|
||||
|
||||
@@ -33,9 +33,7 @@ export interface OnboardingModalController {
|
||||
close: () => void;
|
||||
|
||||
// Data needed for modal
|
||||
llmProviders:
|
||||
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
|
||||
| undefined;
|
||||
llmProviders: import("@/interfaces/llm").LLMProviderDescriptor[] | undefined;
|
||||
initialValues: {
|
||||
firstName: string;
|
||||
lastName: string;
|
||||
@@ -54,7 +52,9 @@ export interface OnboardingModalController {
|
||||
completeUserInfo: (info: BuildUserInfo) => Promise<void>;
|
||||
completeLlmSetup: () => Promise<void>;
|
||||
refetchLlmProviders: () => Promise<
|
||||
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
|
||||
| import("@/interfaces/llm").LLMProviderResponse<
|
||||
import("@/interfaces/llm").LLMProviderDescriptor
|
||||
>
|
||||
| undefined
|
||||
>;
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ import Switch from "@/refresh-components/inputs/Switch";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import NotAllowedModal from "@/app/craft/onboarding/components/NotAllowedModal";
|
||||
import { useOnboarding } from "@/app/craft/onboarding/BuildOnboardingProvider";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import {
|
||||
|
||||
@@ -27,6 +27,7 @@ const SETTINGS_LAYOUT_PREFIXES = [
|
||||
"/admin/document-index-migration",
|
||||
"/admin/discord-bot",
|
||||
"/admin/theme",
|
||||
"/admin/configuration/llm",
|
||||
];
|
||||
|
||||
export function ClientLayout({
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import React, {
|
||||
createContext,
|
||||
useContext,
|
||||
@@ -11,9 +11,9 @@ import React, {
|
||||
useCallback,
|
||||
} from "react";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
|
||||
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llm/svc";
|
||||
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llmConfig/svc";
|
||||
|
||||
interface ProviderContextType {
|
||||
shouldShowConfigurationNeeded: boolean;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { useMemo } from "react";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llm/utils";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llmConfig/utils";
|
||||
import { DefaultModel, LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import { createIcon } from "@/components/icons/icons";
|
||||
@@ -23,6 +23,7 @@ export interface LLMSelectorProps {
|
||||
name?: string;
|
||||
userSettings?: boolean;
|
||||
llmProviders: LLMProviderDescriptor[];
|
||||
defaultText?: DefaultModel | null;
|
||||
currentLlm: string | null;
|
||||
onSelect: (value: string | null) => void;
|
||||
requiresImageGeneration?: boolean;
|
||||
@@ -33,6 +34,7 @@ export default function LLMSelector({
|
||||
name,
|
||||
userSettings,
|
||||
llmProviders,
|
||||
defaultText,
|
||||
currentLlm,
|
||||
onSelect,
|
||||
requiresImageGeneration,
|
||||
@@ -139,11 +141,11 @@ export default function LLMSelector({
|
||||
});
|
||||
}, [llmOptions]);
|
||||
|
||||
const defaultProvider = llmProviders.find(
|
||||
(llmProvider) => llmProvider.is_default_provider
|
||||
);
|
||||
const defaultProvider = defaultText
|
||||
? llmProviders.find((p) => p.id === defaultText.provider_id)
|
||||
: undefined;
|
||||
|
||||
const defaultModelName = defaultProvider?.default_model_name;
|
||||
const defaultModelName = defaultText?.model_name;
|
||||
const defaultModelConfig = defaultProvider?.model_configurations.find(
|
||||
(m) => m.name === defaultModelName
|
||||
);
|
||||
|
||||
@@ -42,7 +42,7 @@ import {
|
||||
getFinalLLM,
|
||||
modelSupportsImageInput,
|
||||
structureValue,
|
||||
} from "@/lib/llm/utils";
|
||||
} from "@/lib/llmConfig/utils";
|
||||
import {
|
||||
CurrentMessageFIFO,
|
||||
updateCurrentMessageFIFO,
|
||||
|
||||
150
web/src/hooks/useLLMProviders.ts
Normal file
150
web/src/hooks/useLLMProviders.ts
Normal file
@@ -0,0 +1,150 @@
|
||||
"use client";
|
||||
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import {
|
||||
LLMProviderDescriptor,
|
||||
LLMProviderResponse,
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
|
||||
/**
|
||||
* Fetches configured LLM providers accessible to the current user.
|
||||
*
|
||||
* Hits the **non-admin** endpoints which return `LLMProviderDescriptor`
|
||||
* (no `id` or sensitive fields like `api_key`). Use this hook in
|
||||
* user-facing UI (chat, popovers, onboarding) where you need the list
|
||||
* of providers and their visible models but don't need admin-level details.
|
||||
*
|
||||
* The backend wraps the provider list in an `LLMProviderResponse` envelope
|
||||
* that also carries the global default text and vision models. This hook
|
||||
* unwraps `.providers` for convenience while still exposing the defaults.
|
||||
*
|
||||
* **Endpoints:**
|
||||
* - No `personaId` → `GET /api/llm/provider`
|
||||
* Returns all public providers plus restricted providers the user can
|
||||
* access via group membership.
|
||||
* - With `personaId` → `GET /api/llm/persona/{personaId}/providers`
|
||||
* Returns providers scoped to a specific persona, respecting RBAC
|
||||
* restrictions. Use this when displaying model options for a particular
|
||||
* assistant.
|
||||
*
|
||||
* @param personaId - Optional persona ID for RBAC-scoped providers.
|
||||
*
|
||||
* @returns
|
||||
* - `llmProviders` — The array of provider descriptors, or `undefined`
|
||||
* while loading.
|
||||
* - `defaultText` — The global (or persona-overridden) default text model.
|
||||
* - `defaultVision` — The global (or persona-overridden) default vision model.
|
||||
* - `isLoading` — `true` until the first successful response or error.
|
||||
* - `error` — The SWR error object, if any.
|
||||
* - `refetch` — SWR `mutate` function to trigger a revalidation.
|
||||
*/
|
||||
export function useLLMProviders(personaId?: number) {
|
||||
const url =
|
||||
personaId !== undefined
|
||||
? `/api/llm/persona/${personaId}/providers`
|
||||
: "/api/llm/provider";
|
||||
|
||||
const { data, error, mutate } = useSWR<
|
||||
LLMProviderResponse<LLMProviderDescriptor>
|
||||
>(url, errorHandlingFetcher, {
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
});
|
||||
|
||||
return {
|
||||
llmProviders: data?.providers,
|
||||
defaultText: data?.default_text ?? null,
|
||||
defaultVision: data?.default_vision ?? null,
|
||||
isLoading: !error && !data,
|
||||
error,
|
||||
refetch: mutate,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches configured LLM providers via the **admin** endpoint.
|
||||
*
|
||||
* Hits `GET /api/admin/llm/provider` which returns `LLMProviderView` —
|
||||
* the full provider object including `id`, `api_key` (masked),
|
||||
* group/persona assignments, and all other admin-visible fields.
|
||||
*
|
||||
* Use this hook on admin pages (e.g. the LLM Configuration page) where
|
||||
* you need provider IDs for mutations (setting defaults, editing, deleting)
|
||||
* or need to display admin-only metadata. **Do not use in user-facing UI**
|
||||
* — use `useLLMProviders` instead.
|
||||
*
|
||||
* @returns
|
||||
* - `llmProviders` — The array of full provider views, or `undefined`
|
||||
* while loading.
|
||||
* - `defaultText` — The global default text model.
|
||||
* - `defaultVision` — The global default vision model.
|
||||
* - `isLoading` — `true` until the first successful response or error.
|
||||
* - `error` — The SWR error object, if any.
|
||||
* - `refetch` — SWR `mutate` function to trigger a revalidation.
|
||||
*/
|
||||
export function useAdminLLMProviders() {
|
||||
const { data, error, mutate } = useSWR<LLMProviderResponse<LLMProviderView>>(
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
llmProviders: data?.providers,
|
||||
defaultText: data?.default_text ?? null,
|
||||
defaultVision: data?.default_vision ?? null,
|
||||
isLoading: !error && !data,
|
||||
error,
|
||||
refetch: mutate,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches the catalog of well-known (built-in) LLM providers.
|
||||
*
|
||||
* Hits `GET /api/admin/llm/built-in/options` which returns the static
|
||||
* list of provider descriptors that Onyx ships with out of the box
|
||||
* (OpenAI, Anthropic, Vertex AI, Bedrock, Azure, Ollama, OpenRouter,
|
||||
* etc.). Each descriptor includes the provider's known models and the
|
||||
* recommended default model.
|
||||
*
|
||||
* Used primarily on the LLM Configuration page and onboarding flows
|
||||
* to show which providers are available to set up, and to pre-populate
|
||||
* model lists before the user has entered credentials.
|
||||
*
|
||||
* @returns
|
||||
* - `wellKnownLLMProviders` — The array of built-in provider descriptors,
|
||||
* or `null` while loading.
|
||||
* - `isLoading` — `true` until the first successful response or error.
|
||||
* - `error` — The SWR error object, if any.
|
||||
* - `mutate` — SWR `mutate` function to trigger a revalidation.
|
||||
*/
|
||||
export function useWellKnownLLMProviders() {
|
||||
const {
|
||||
data: wellKnownLLMProviders,
|
||||
error,
|
||||
isLoading,
|
||||
mutate,
|
||||
} = useSWR<WellKnownLLMProviderDescriptor[]>(
|
||||
"/api/admin/llm/built-in/options",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
wellKnownLLMProviders: wellKnownLLMProviders ?? null,
|
||||
isLoading,
|
||||
error,
|
||||
mutate,
|
||||
};
|
||||
}
|
||||
@@ -13,8 +13,8 @@ export interface ModelConfiguration {
|
||||
name: string;
|
||||
is_visible: boolean;
|
||||
max_input_tokens: number | null;
|
||||
supports_image_input: boolean | null;
|
||||
supports_reasoning?: boolean;
|
||||
supports_image_input: boolean;
|
||||
supports_reasoning: boolean;
|
||||
display_name?: string;
|
||||
provider_display_name?: string;
|
||||
vendor?: string;
|
||||
@@ -30,7 +30,6 @@ export interface SimpleKnownModel {
|
||||
export interface WellKnownLLMProviderDescriptor {
|
||||
name: string;
|
||||
known_models: ModelConfiguration[];
|
||||
|
||||
recommended_default_model: SimpleKnownModel | null;
|
||||
}
|
||||
|
||||
@@ -40,44 +39,31 @@ export interface LLMModelDescriptor {
|
||||
maxTokens: number;
|
||||
}
|
||||
|
||||
export interface LLMProvider {
|
||||
export interface LLMProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
api_key: string | null;
|
||||
api_base: string | null;
|
||||
api_version: string | null;
|
||||
custom_config: { [key: string]: string } | null;
|
||||
default_model_name: string;
|
||||
is_public: boolean;
|
||||
is_auto_mode: boolean;
|
||||
groups: number[];
|
||||
personas: number[];
|
||||
deployment_name: string | null;
|
||||
default_vision_model: string | null;
|
||||
is_default_vision_provider: boolean | null;
|
||||
model_configurations: ModelConfiguration[];
|
||||
}
|
||||
|
||||
export interface LLMProviderView extends LLMProvider {
|
||||
id: number;
|
||||
is_default_provider: boolean | null;
|
||||
}
|
||||
|
||||
export interface VisionProvider extends LLMProviderView {
|
||||
vision_models: string[];
|
||||
}
|
||||
|
||||
export interface LLMProviderDescriptor {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
provider_display_name?: string;
|
||||
default_model_name: string;
|
||||
is_default_provider: boolean | null;
|
||||
is_default_vision_provider?: boolean | null;
|
||||
default_vision_model?: string | null;
|
||||
is_public?: boolean;
|
||||
groups?: number[];
|
||||
personas?: number[];
|
||||
provider_display_name: string;
|
||||
model_configurations: ModelConfiguration[];
|
||||
}
|
||||
|
||||
@@ -102,9 +88,22 @@ export interface BedrockModelResponse {
|
||||
supports_image_input: boolean;
|
||||
}
|
||||
|
||||
export interface DefaultModel {
|
||||
provider_id: number;
|
||||
model_name: string;
|
||||
}
|
||||
|
||||
export interface LLMProviderResponse<T> {
|
||||
providers: T[];
|
||||
default_text: DefaultModel | null;
|
||||
default_vision: DefaultModel | null;
|
||||
}
|
||||
|
||||
export interface LLMProviderFormProps {
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
open?: boolean;
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
// Param types for model fetching functions - use snake_case to match API structure
|
||||
@@ -2,8 +2,8 @@ import {
|
||||
getDefaultLlmDescriptor,
|
||||
getValidLlmDescriptorForProviders,
|
||||
} from "@/lib/hooks";
|
||||
import { structureValue } from "@/lib/llm/utils";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { structureValue } from "@/lib/llmConfig/utils";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { makeProvider } from "@tests/setup/llmProviderTestUtils";
|
||||
|
||||
describe("LLM resolver helpers", () => {
|
||||
@@ -11,29 +11,30 @@ describe("LLM resolver helpers", () => {
|
||||
const sharedModel = "shared-runtime-model";
|
||||
const providers: LLMProviderDescriptor[] = [
|
||||
makeProvider({
|
||||
id: 1,
|
||||
name: "OpenAI Provider",
|
||||
provider: "openai",
|
||||
default_model_name: sharedModel,
|
||||
is_default_provider: true,
|
||||
model_configurations: [
|
||||
{
|
||||
name: sharedModel,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
makeProvider({
|
||||
id: 2,
|
||||
name: "Anthropic Provider",
|
||||
provider: "anthropic",
|
||||
default_model_name: sharedModel,
|
||||
model_configurations: [
|
||||
{
|
||||
name: sharedModel,
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
@@ -54,29 +55,30 @@ describe("LLM resolver helpers", () => {
|
||||
test("falls back to default provider when model is unavailable", () => {
|
||||
const providers: LLMProviderDescriptor[] = [
|
||||
makeProvider({
|
||||
id: 10,
|
||||
name: "Default OpenAI",
|
||||
provider: "openai",
|
||||
default_model_name: "gpt-4o-mini",
|
||||
is_default_provider: true,
|
||||
model_configurations: [
|
||||
{
|
||||
name: "gpt-4o-mini",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
makeProvider({
|
||||
id: 20,
|
||||
name: "Anthropic Backup",
|
||||
provider: "anthropic",
|
||||
default_model_name: "claude-3-5-sonnet",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-3-5-sonnet",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
@@ -97,30 +99,30 @@ describe("LLM resolver helpers", () => {
|
||||
test("uses first provider with models when no explicit default exists", () => {
|
||||
const providers: LLMProviderDescriptor[] = [
|
||||
makeProvider({
|
||||
id: 30,
|
||||
name: "First Provider",
|
||||
provider: "openai",
|
||||
default_model_name: "gpt-first",
|
||||
is_default_provider: false,
|
||||
model_configurations: [
|
||||
{
|
||||
name: "gpt-first",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
makeProvider({
|
||||
id: 40,
|
||||
name: "Second Provider",
|
||||
provider: "anthropic",
|
||||
default_model_name: "claude-second",
|
||||
is_default_provider: false,
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-second",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
|
||||
@@ -23,23 +23,22 @@ import {
|
||||
} from "react";
|
||||
import { DateRangePickerValue } from "@/components/dateRangeSelectors/AdminDateRangeSelector";
|
||||
import { SourceMetadata } from "./search/interfaces";
|
||||
import { parseLlmDescriptor } from "./llm/utils";
|
||||
import { parseLlmDescriptor } from "./llmConfig/utils";
|
||||
import { ChatSession } from "@/app/app/interfaces";
|
||||
import { AllUsersResponse } from "./types";
|
||||
import { Credential } from "./connectors/credentials";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import {
|
||||
MinimalPersonaSnapshot,
|
||||
PersonaLabel,
|
||||
} from "@/app/admin/assistants/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { DefaultModel, LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { isAnthropic } from "@/app/admin/configuration/llm/utils";
|
||||
import { getSourceMetadataForSources } from "./sources";
|
||||
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { SEARCH_TOOL_ID } from "@/app/app/components/tools/constants";
|
||||
import { updateTemperatureOverrideForChatSession } from "@/app/app/services/lib";
|
||||
import { useLLMProviders } from "./hooks/useLLMProviders";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
|
||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||
|
||||
@@ -525,26 +524,31 @@ providing appropriate defaults for new conversations based on the available tool
|
||||
*/
|
||||
|
||||
export function getDefaultLlmDescriptor(
|
||||
llmProviders: LLMProviderDescriptor[]
|
||||
llmProviders: LLMProviderDescriptor[],
|
||||
defaultText?: DefaultModel | null
|
||||
): LlmDescriptor | null {
|
||||
const defaultProvider = llmProviders.find(
|
||||
(provider) => provider.is_default_provider
|
||||
);
|
||||
if (defaultProvider) {
|
||||
return {
|
||||
name: defaultProvider.name,
|
||||
provider: defaultProvider.provider,
|
||||
modelName: defaultProvider.default_model_name,
|
||||
};
|
||||
if (defaultText) {
|
||||
const provider = llmProviders.find((p) => p.id === defaultText.provider_id);
|
||||
if (provider) {
|
||||
return {
|
||||
name: provider.name,
|
||||
provider: provider.provider,
|
||||
modelName: defaultText.model_name,
|
||||
};
|
||||
}
|
||||
}
|
||||
// Fallback: first provider with visible models
|
||||
const firstLlmProvider = llmProviders.find(
|
||||
(provider) => provider.model_configurations.length > 0
|
||||
);
|
||||
if (firstLlmProvider) {
|
||||
const firstModel = firstLlmProvider.model_configurations.find(
|
||||
(m) => m.is_visible
|
||||
);
|
||||
return {
|
||||
name: firstLlmProvider.name,
|
||||
provider: firstLlmProvider.provider,
|
||||
modelName: firstLlmProvider.default_model_name,
|
||||
modelName: firstModel?.name ?? "",
|
||||
};
|
||||
}
|
||||
return null;
|
||||
@@ -629,19 +633,25 @@ export function useLlmManager(
|
||||
|
||||
// Get all user-accessible providers via SWR (general providers - no persona filter)
|
||||
// This includes public + all restricted providers user can access via groups
|
||||
const { llmProviders: allUserProviders, isLoading: isLoadingAllProviders } =
|
||||
useLLMProviders();
|
||||
const {
|
||||
llmProviders: allUserProviders,
|
||||
defaultText: allUserDefaultText,
|
||||
isLoading: isLoadingAllProviders,
|
||||
} = useLLMProviders();
|
||||
// Fetch persona-specific providers to enforce RBAC restrictions per assistant
|
||||
// Only fetch if we have an assistant selected
|
||||
const personaId =
|
||||
liveAssistant?.id !== undefined ? liveAssistant.id : undefined;
|
||||
const {
|
||||
llmProviders: personaProviders,
|
||||
defaultText: personaDefaultText,
|
||||
isLoading: isLoadingPersonaProviders,
|
||||
} = useLLMProviders(personaId);
|
||||
|
||||
const llmProviders =
|
||||
personaProviders !== undefined ? personaProviders : allUserProviders;
|
||||
const defaultText =
|
||||
personaProviders !== undefined ? personaDefaultText : allUserDefaultText;
|
||||
|
||||
const [userHasManuallyOverriddenLLM, setUserHasManuallyOverriddenLLM] =
|
||||
useState(false);
|
||||
@@ -700,7 +710,7 @@ export function useLlmManager(
|
||||
} else if (user?.preferences?.default_model) {
|
||||
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
|
||||
} else {
|
||||
const defaultLlm = getDefaultLlmDescriptor(llmProviders);
|
||||
const defaultLlm = getDefaultLlmDescriptor(llmProviders, defaultText);
|
||||
if (defaultLlm) {
|
||||
setCurrentLlm(defaultLlm);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import useSWR from "swr";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
export function useLLMProviderOptions() {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import useSWR from "swr";
|
||||
import { useLLMProviders } from "./useLLMProviders";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
jest.mock("swr", () => ({
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
import useSWR from "swr";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
export function useLLMProviders(personaId?: number) {
|
||||
// personaId can be:
|
||||
// - undefined: public providers only (/api/llm/provider)
|
||||
// - number (personaId): persona-specific providers with RBAC enforcement
|
||||
|
||||
const url =
|
||||
typeof personaId === "number"
|
||||
? `/api/llm/persona/${personaId}/providers`
|
||||
: "/api/llm/provider";
|
||||
|
||||
const { data, error, mutate } = useSWR<LLMProviderDescriptor[] | undefined>(
|
||||
url,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false, // Cache aggressively for performance
|
||||
dedupingInterval: 60000, // Dedupe requests within 1 minute
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
llmProviders: data,
|
||||
isLoading: !error && !data,
|
||||
error,
|
||||
refetch: mutate,
|
||||
};
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { fetchSS } from "../utilsSS";
|
||||
|
||||
export async function fetchLLMProvidersSS() {
|
||||
const response = await fetchSS("/llm/provider");
|
||||
if (response.ok) {
|
||||
return (await response.json()) as LLMProviderDescriptor[];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
/**
|
||||
* LLM action functions for mutations.
|
||||
*
|
||||
* These are async functions for one-off actions that don't need SWR caching.
|
||||
*
|
||||
* Endpoints:
|
||||
* - /api/admin/llm/test/default - Test the default LLM provider connection
|
||||
*/
|
||||
|
||||
/**
|
||||
* Test the default LLM provider.
|
||||
* Returns true if the default provider is configured and working, false otherwise.
|
||||
*/
|
||||
export async function testDefaultProvider(): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch("/api/admin/llm/test/default", {
|
||||
method: "POST",
|
||||
});
|
||||
return response?.ok || false;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
||||
export const LLM_ADMIN_URL = "/api/admin/llm";
|
||||
export const LLM_PROVIDERS_ADMIN_URL = `${LLM_ADMIN_URL}/provider`;
|
||||
|
||||
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
|
||||
"/api/admin/llm/provider-contextual-cost";
|
||||
68
web/src/lib/llmConfig/providers.ts
Normal file
68
web/src/lib/llmConfig/providers.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import {
|
||||
SvgCpu,
|
||||
SvgOpenai,
|
||||
SvgClaude,
|
||||
SvgOllama,
|
||||
SvgCloud,
|
||||
SvgAws,
|
||||
SvgOpenrouter,
|
||||
SvgServer,
|
||||
SvgAzure,
|
||||
SvgGemini,
|
||||
SvgLitellm,
|
||||
} from "@opal/icons";
|
||||
|
||||
const PROVIDER_ICONS: Record<string, IconFunctionComponent> = {
|
||||
openai: SvgOpenai,
|
||||
anthropic: SvgClaude,
|
||||
vertex_ai: SvgGemini,
|
||||
bedrock: SvgAws,
|
||||
azure: SvgAzure,
|
||||
litellm: SvgLitellm,
|
||||
ollama_chat: SvgOllama,
|
||||
openrouter: SvgOpenrouter,
|
||||
|
||||
// fallback
|
||||
custom: SvgServer,
|
||||
};
|
||||
|
||||
const PROVIDER_PRODUCT_NAMES: Record<string, string> = {
|
||||
openai: "GPT",
|
||||
anthropic: "Claude",
|
||||
vertex_ai: "Gemini",
|
||||
bedrock: "Amazon Bedrock",
|
||||
azure: "Azure OpenAI",
|
||||
litellm: "LiteLLM",
|
||||
ollama_chat: "Ollama",
|
||||
openrouter: "OpenRouter",
|
||||
|
||||
// fallback
|
||||
custom: "Custom Models",
|
||||
};
|
||||
|
||||
const PROVIDER_DISPLAY_NAMES: Record<string, string> = {
|
||||
openai: "OpenAI",
|
||||
anthropic: "Anthropic",
|
||||
vertex_ai: "Google Cloud Vertex AI",
|
||||
bedrock: "AWS",
|
||||
azure: "Microsoft Azure",
|
||||
litellm: "LiteLLM",
|
||||
ollama_chat: "Ollama",
|
||||
openrouter: "OpenRouter",
|
||||
|
||||
// fallback
|
||||
custom: "Other providers or self-hosted",
|
||||
};
|
||||
|
||||
export function getProviderProductName(providerName: string): string {
|
||||
return PROVIDER_PRODUCT_NAMES[providerName] ?? providerName;
|
||||
}
|
||||
|
||||
export function getProviderDisplayName(providerName: string): string {
|
||||
return PROVIDER_DISPLAY_NAMES[providerName] ?? providerName;
|
||||
}
|
||||
|
||||
export function getProviderIcon(providerName: string): IconFunctionComponent {
|
||||
return PROVIDER_ICONS[providerName] ?? SvgCpu;
|
||||
}
|
||||
71
web/src/lib/llmConfig/svc.ts
Normal file
71
web/src/lib/llmConfig/svc.ts
Normal file
@@ -0,0 +1,71 @@
|
||||
/**
|
||||
* LLM action functions for mutations.
|
||||
*
|
||||
* These are async functions for one-off actions that don't need SWR caching.
|
||||
*
|
||||
* Endpoints:
|
||||
* - /api/admin/llm/test/default - Test the default LLM provider connection
|
||||
* - /api/admin/llm/default - Set the default LLM model
|
||||
* - /api/admin/llm/provider/{id} - Delete an LLM provider
|
||||
*/
|
||||
|
||||
import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
|
||||
/**
|
||||
* Test the default LLM provider.
|
||||
* Returns true if the default provider is configured and working, false otherwise.
|
||||
*/
|
||||
export async function testDefaultProvider(): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`${LLM_ADMIN_URL}/test/default`, {
|
||||
method: "POST",
|
||||
});
|
||||
return response?.ok || false;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the default LLM model.
|
||||
* @param providerId - The provider ID
|
||||
* @param modelName - The model name within that provider
|
||||
* @throws Error with the detail message from the API on failure
|
||||
*/
|
||||
export async function setDefaultLlmModel(
|
||||
providerId: number,
|
||||
modelName: string
|
||||
): Promise<void> {
|
||||
const response = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_id: providerId,
|
||||
model_name: modelName,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
throw new Error(errorMsg);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete an LLM provider.
|
||||
* @param providerId - The provider ID to delete
|
||||
* @throws Error with the detail message from the API on failure
|
||||
*/
|
||||
export async function deleteLlmProvider(providerId: number): Promise<void> {
|
||||
const response = await fetch(`${LLM_PROVIDERS_ADMIN_URL}/${providerId}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
throw new Error(errorMsg);
|
||||
}
|
||||
}
|
||||
@@ -1,21 +1,28 @@
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import {
|
||||
DefaultModel,
|
||||
LLMProviderDescriptor,
|
||||
ModelConfiguration,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import { LlmDescriptor } from "@/lib/hooks";
|
||||
|
||||
export function getFinalLLM(
|
||||
llmProviders: LLMProviderDescriptor[],
|
||||
persona: MinimalPersonaSnapshot | null,
|
||||
currentLlm: LlmDescriptor | null
|
||||
currentLlm: LlmDescriptor | null,
|
||||
defaultText?: DefaultModel | null
|
||||
): [string, string] {
|
||||
const defaultProvider = llmProviders.find(
|
||||
(llmProvider) => llmProvider.is_default_provider
|
||||
);
|
||||
const defaultProvider = defaultText
|
||||
? llmProviders.find((p) => p.id === defaultText.provider_id)
|
||||
: llmProviders.find((p) =>
|
||||
p.model_configurations.some((m) => m.is_visible)
|
||||
);
|
||||
|
||||
let provider = defaultProvider?.provider || "";
|
||||
let model = defaultProvider?.default_model_name || "";
|
||||
let model =
|
||||
defaultText?.model_name ||
|
||||
defaultProvider?.model_configurations.find((m) => m.is_visible)?.name ||
|
||||
"";
|
||||
|
||||
if (persona) {
|
||||
// Map "provider override" to actual LLLMProvider
|
||||
@@ -1,4 +1,4 @@
|
||||
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { VisionProvider } from "@/interfaces/llm";
|
||||
|
||||
export async function fetchVisionProviders(): Promise<VisionProvider[]> {
|
||||
const response = await fetch("/api/admin/llm/vision-providers", {
|
||||
@@ -4,7 +4,7 @@ import NameStep from "./steps/NameStep";
|
||||
import LLMStep from "./steps/LLMStep";
|
||||
import FinalStep from "./steps/FinalStep";
|
||||
import { OnboardingActions, OnboardingState, OnboardingStep } from "./types";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { UserRole } from "@/lib/types";
|
||||
import NonAdminStep from "./components/NonAdminStep";
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ModelConfiguration } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { ModelConfiguration } from "@/interfaces/llm";
|
||||
import { parseAzureTargetUri } from "@/lib/azureTargetUri";
|
||||
|
||||
export const buildInitialValues = () => ({
|
||||
@@ -93,7 +93,7 @@ export const testApiKeyHelper = async (
|
||||
...(formValues?.custom_config ?? {}),
|
||||
...(customConfigOverride ?? {}),
|
||||
},
|
||||
default_model_name: modelName ?? formValues?.default_model_name ?? "",
|
||||
model: modelName ?? formValues?.default_model_name ?? "",
|
||||
model_configurations: [
|
||||
...(formValues.model_configurations || []).map(
|
||||
(model: ModelConfiguration) => ({
|
||||
|
||||
@@ -8,7 +8,7 @@ import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -6,7 +6,7 @@ import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -11,7 +11,7 @@ import Text from "@/refresh-components/texts/Text";
|
||||
import { Button } from "@opal/components";
|
||||
import { cn, noProp } from "@/lib/utils";
|
||||
import { SvgAlertCircle, SvgRefreshCw } from "@opal/icons";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -12,7 +12,7 @@ import { Button } from "@opal/components";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { cn, noProp } from "@/lib/utils";
|
||||
import { SvgRefreshCw } from "@opal/icons";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -5,8 +5,8 @@ import ProviderModal from "@/components/modals/ProviderModal";
|
||||
import {
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
|
||||
} from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { OnboardingActions, OnboardingState } from "../types";
|
||||
import { APIFormFieldState } from "@/refresh-components/form/types";
|
||||
import {
|
||||
|
||||
@@ -8,7 +8,7 @@ import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -8,7 +8,7 @@ import Separator from "@/refresh-components/Separator";
|
||||
import { Button } from "@opal/components";
|
||||
import { cn, noProp } from "@/lib/utils";
|
||||
import { SvgRefreshCw } from "@opal/icons";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -10,7 +10,7 @@ import { SvgRefreshCw } from "@opal/icons";
|
||||
import {
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingFormWrapper,
|
||||
OnboardingFormChildProps,
|
||||
|
||||
@@ -9,7 +9,7 @@ import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
ModelConfiguration,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
OnboardingState,
|
||||
OnboardingActions,
|
||||
@@ -34,12 +34,14 @@ export function createMockLLMDescriptor(
|
||||
is_visible: true,
|
||||
max_input_tokens: 4096,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "test-model-2",
|
||||
is_visible: true,
|
||||
max_input_tokens: 8192,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
recommended_default_model: null,
|
||||
@@ -170,36 +172,42 @@ export const OPENAI_DEFAULT_VISIBLE_MODELS = [
|
||||
is_visible: true,
|
||||
max_input_tokens: 128000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-5-mini",
|
||||
is_visible: true,
|
||||
max_input_tokens: 128000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "o1",
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "o3-mini",
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o",
|
||||
is_visible: true,
|
||||
max_input_tokens: 128000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "gpt-4o-mini",
|
||||
is_visible: true,
|
||||
max_input_tokens: 128000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -214,18 +222,21 @@ export const ANTHROPIC_DEFAULT_VISIBLE_MODELS = [
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "claude-sonnet-4-5",
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "claude-haiku-4-5",
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -240,18 +251,21 @@ export const VERTEXAI_DEFAULT_VISIBLE_MODELS = [
|
||||
is_visible: true,
|
||||
max_input_tokens: 1048576,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "gemini-2.5-flash-lite",
|
||||
is_visible: true,
|
||||
max_input_tokens: 1048576,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "gemini-2.5-pro",
|
||||
is_visible: true,
|
||||
max_input_tokens: 1048576,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
];
|
||||
|
||||
@@ -270,12 +284,14 @@ export const MOCK_PROVIDERS = {
|
||||
is_visible: true,
|
||||
max_input_tokens: 4096,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "mistral",
|
||||
is_visible: true,
|
||||
max_input_tokens: 8192,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
]),
|
||||
azure: createMockLLMDescriptor(LLMProviderName.AZURE, [
|
||||
@@ -284,6 +300,7 @@ export const MOCK_PROVIDERS = {
|
||||
is_visible: true,
|
||||
max_input_tokens: 8192,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
]),
|
||||
bedrock: createMockLLMDescriptor(LLMProviderName.BEDROCK, [
|
||||
@@ -292,6 +309,7 @@ export const MOCK_PROVIDERS = {
|
||||
is_visible: true,
|
||||
max_input_tokens: 200000,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
]),
|
||||
vertexAi: createMockLLMDescriptor(
|
||||
@@ -304,6 +322,7 @@ export const MOCK_PROVIDERS = {
|
||||
is_visible: true,
|
||||
max_input_tokens: 8192,
|
||||
supports_image_input: true,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
]),
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import React from "react";
|
||||
import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
LLMProviderName,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import { OnboardingActions, OnboardingState } from "../types";
|
||||
import { OpenAIOnboardingForm } from "./OpenAIOnboardingForm";
|
||||
import { AnthropicOnboardingForm } from "./AnthropicOnboardingForm";
|
||||
|
||||
@@ -4,7 +4,7 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import LLMProviderCard from "../components/LLMProviderCard";
|
||||
import { OnboardingActions, OnboardingState, OnboardingStep } from "../types";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import {
|
||||
getOnboardingForm,
|
||||
getProviderDisplayInfo,
|
||||
|
||||
@@ -7,11 +7,11 @@ import {
|
||||
OnboardingState,
|
||||
OnboardingStep,
|
||||
} from "./types";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { WellKnownLLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { updateUserPersonalization } from "@/lib/userSettings";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { useProviderStatus } from "@/components/chat/ProviderContext";
|
||||
|
||||
export function useOnboardingState(liveAssistant?: MinimalPersonaSnapshot): {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { buildLlmOptions, groupLlmOptions } from "./LLMPopover";
|
||||
import { LLMOption } from "./interfaces";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { makeProvider } from "@tests/setup/llmProviderTestUtils";
|
||||
|
||||
describe("LLMPopover helpers", () => {
|
||||
@@ -15,6 +15,7 @@ describe("LLMPopover helpers", () => {
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
@@ -27,6 +28,7 @@ describe("LLMPopover helpers", () => {
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
@@ -39,6 +41,7 @@ describe("LLMPopover helpers", () => {
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
@@ -67,6 +70,7 @@ describe("LLMPopover helpers", () => {
|
||||
is_visible: false,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { structureValue } from "@/lib/llm/utils";
|
||||
import { structureValue } from "@/lib/llmConfig/utils";
|
||||
import {
|
||||
getProviderIcon,
|
||||
AGGREGATOR_PROVIDERS,
|
||||
} from "@/app/admin/configuration/llm/utils";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { LLMProviderDescriptor } from "@/interfaces/llm";
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
|
||||
@@ -17,8 +17,8 @@ import Separator from "@/refresh-components/Separator";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { useFormikContext } from "formik";
|
||||
import LLMSelector from "@/components/llm/LLMSelector";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llm/utils";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llmConfig/utils";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import {
|
||||
STARTER_MESSAGES_EXAMPLES,
|
||||
MAX_CHARACTERS_STARTER_MESSAGE,
|
||||
|
||||
477
web/src/refresh-pages/admin/LLMConfigurationPage.tsx
Normal file
477
web/src/refresh-pages/admin/LLMConfigurationPage.tsx
Normal file
@@ -0,0 +1,477 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
useAdminLLMProviders,
|
||||
useWellKnownLLMProviders,
|
||||
} from "@/hooks/useLLMProviders";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Content, ContentAction } from "@opal/layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgCpu, SvgArrowExchange, SvgSettings, SvgTrash } from "@opal/icons";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import * as GeneralLayouts from "@/layouts/general-layouts";
|
||||
import {
|
||||
getProviderDisplayName,
|
||||
getProviderIcon,
|
||||
getProviderProductName,
|
||||
} from "@/lib/llmConfig/providers";
|
||||
import { deleteLlmProvider, setDefaultLlmModel } from "@/lib/llmConfig/svc";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Horizontal as HorizontalInput } from "@/layouts/input-layouts";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import {
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { getModalForExistingProvider } from "@/sections/modals/llmConfig/getModal";
|
||||
import { OpenAIModal } from "@/sections/modals/llmConfig/OpenAIModal";
|
||||
import { AnthropicModal } from "@/sections/modals/llmConfig/AnthropicModal";
|
||||
import { OllamaModal } from "@/sections/modals/llmConfig/OllamaModal";
|
||||
import { AzureModal } from "@/sections/modals/llmConfig/AzureModal";
|
||||
import { BedrockModal } from "@/sections/modals/llmConfig/BedrockModal";
|
||||
import { VertexAIModal } from "@/sections/modals/llmConfig/VertexAIModal";
|
||||
import { OpenRouterModal } from "@/sections/modals/llmConfig/OpenRouterModal";
|
||||
import { CustomModal } from "@/sections/modals/llmConfig/CustomModal";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
// ============================================================================
|
||||
// Provider form mapping (keyed by provider name from the API)
|
||||
// ============================================================================
|
||||
|
||||
const PROVIDER_MODAL_MAP: Record<
|
||||
string,
|
||||
(
|
||||
shouldMarkAsDefault: boolean,
|
||||
open: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode
|
||||
> = {
|
||||
openai: (d, open, onOpenChange) => (
|
||||
<OpenAIModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
anthropic: (d, open, onOpenChange) => (
|
||||
<AnthropicModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
ollama_chat: (d, open, onOpenChange) => (
|
||||
<OllamaModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
azure: (d, open, onOpenChange) => (
|
||||
<AzureModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
bedrock: (d, open, onOpenChange) => (
|
||||
<BedrockModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
vertex_ai: (d, open, onOpenChange) => (
|
||||
<VertexAIModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
openrouter: (d, open, onOpenChange) => (
|
||||
<OpenRouterModal
|
||||
shouldMarkAsDefault={d}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
/>
|
||||
),
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// ExistingProviderCard — card for configured (existing) providers
|
||||
// ============================================================================
|
||||
|
||||
interface ExistingProviderCardProps {
|
||||
provider: LLMProviderView;
|
||||
isDefault: boolean;
|
||||
}
|
||||
|
||||
function ExistingProviderCard({
|
||||
provider,
|
||||
isDefault,
|
||||
}: ExistingProviderCardProps) {
|
||||
const { mutate } = useSWRConfig();
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const deleteModal = useCreateModal();
|
||||
|
||||
const handleDelete = async () => {
|
||||
try {
|
||||
await deleteLlmProvider(provider.id);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
deleteModal.toggle(false);
|
||||
toast.success("Provider deleted successfully!");
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
toast.error(`Failed to delete provider: ${message}`);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{deleteModal.isOpen && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgTrash}
|
||||
title="Delete LLM Provider"
|
||||
onClose={() => deleteModal.toggle(false)}
|
||||
submit={
|
||||
<Button variant="danger" onClick={handleDelete}>
|
||||
Delete
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Section alignItems="start">
|
||||
<Text as="p" text03>
|
||||
This will permanently delete <b>{provider.name}</b> and all of its
|
||||
model configurations.
|
||||
</Text>
|
||||
{isDefault && (
|
||||
<Text as="p" text03>
|
||||
This is currently your <b>default provider</b>. Deleting it will
|
||||
require you to set a new default.
|
||||
</Text>
|
||||
)}
|
||||
</Section>
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
|
||||
{/*
|
||||
# TODO (@raunakab)
|
||||
Abstract into the hover stylings into a proper, core `Hoverable` component inside of Opal later.
|
||||
The API should look something like:
|
||||
|
||||
```tsx
|
||||
<Hoverable.Root group="MyGroup">
|
||||
<Card>
|
||||
<ContentAction
|
||||
// ...
|
||||
rightChildren={
|
||||
<Hoverable.Item group="MyGroup" variant="opacity-on-hover">
|
||||
<Button
|
||||
icon={SvgTrash}
|
||||
// ...
|
||||
/>
|
||||
</Hoverable.Item>
|
||||
}
|
||||
/>
|
||||
</Card>
|
||||
</Hoverable.Group>
|
||||
```
|
||||
*/}
|
||||
<div className="group/ExistingProviderCard">
|
||||
<Card padding={0.5}>
|
||||
<ContentAction
|
||||
icon={getProviderIcon(provider.provider)}
|
||||
title={provider.name}
|
||||
description={getProviderDisplayName(provider.provider)}
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
tag={isDefault ? { title: "Default", color: "blue" } : undefined}
|
||||
rightChildren={
|
||||
<Section flexDirection="row" gap={0} alignItems="start">
|
||||
<div className="opacity-0 group-hover/ExistingProviderCard:opacity-100 transition-all duration-200">
|
||||
<Button
|
||||
icon={SvgTrash}
|
||||
prominence="tertiary"
|
||||
onClick={() => deleteModal.toggle(true)}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
icon={SvgSettings}
|
||||
prominence="tertiary"
|
||||
onClick={() => setIsOpen(true)}
|
||||
/>
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
{getModalForExistingProvider(provider, isOpen, setIsOpen)}
|
||||
</Card>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// NewProviderCard — card for the "Add Provider" list
|
||||
// ============================================================================
|
||||
|
||||
interface NewProviderCardProps {
|
||||
provider: WellKnownLLMProviderDescriptor;
|
||||
isFirstProvider: boolean;
|
||||
formFn: (
|
||||
shouldMarkAsDefault: boolean,
|
||||
open: boolean,
|
||||
onOpenChange: (open: boolean) => void
|
||||
) => React.ReactNode;
|
||||
}
|
||||
|
||||
function NewProviderCard({
|
||||
provider,
|
||||
isFirstProvider,
|
||||
formFn,
|
||||
}: NewProviderCardProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<Card variant="secondary" padding={0.5}>
|
||||
<ContentAction
|
||||
icon={getProviderIcon(provider.name)}
|
||||
title={getProviderProductName(provider.name)}
|
||||
description={getProviderDisplayName(provider.name)}
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
<Button
|
||||
rightIcon={SvgArrowExchange}
|
||||
prominence="tertiary"
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
Connect
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{formFn(isFirstProvider, isOpen, setIsOpen)}
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// NewCustomProviderCard — card for adding a custom LLM provider
|
||||
// ============================================================================
|
||||
|
||||
interface NewCustomProviderCardProps {
|
||||
isFirstProvider: boolean;
|
||||
}
|
||||
|
||||
function NewCustomProviderCard({
|
||||
isFirstProvider,
|
||||
}: NewCustomProviderCardProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<Card variant="secondary" padding={0.5}>
|
||||
<ContentAction
|
||||
icon={getProviderIcon("custom")}
|
||||
title={getProviderProductName("custom")}
|
||||
description={getProviderDisplayName("custom")}
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
rightChildren={
|
||||
<Button
|
||||
rightIcon={SvgArrowExchange}
|
||||
prominence="tertiary"
|
||||
onClick={() => setIsOpen(true)}
|
||||
>
|
||||
Set Up
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<CustomModal
|
||||
shouldMarkAsDefault={isFirstProvider}
|
||||
open={isOpen}
|
||||
onOpenChange={setIsOpen}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LLMConfigurationPage — main page component
|
||||
// ============================================================================
|
||||
|
||||
export default function LLMConfigurationPage() {
|
||||
const { mutate } = useSWRConfig();
|
||||
const { llmProviders: existingLlmProviders, defaultText } =
|
||||
useAdminLLMProviders();
|
||||
const { wellKnownLLMProviders } = useWellKnownLLMProviders();
|
||||
|
||||
if (!existingLlmProviders) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
const hasProviders = existingLlmProviders.length > 0;
|
||||
const isFirstProvider = !hasProviders;
|
||||
|
||||
// Pre-sort providers so the default appears first
|
||||
const sortedProviders = [...existingLlmProviders].sort((a, b) => {
|
||||
const aIsDefault = defaultText?.provider_id === a.id;
|
||||
const bIsDefault = defaultText?.provider_id === b.id;
|
||||
if (aIsDefault && !bIsDefault) return -1;
|
||||
if (!aIsDefault && bIsDefault) return 1;
|
||||
return 0;
|
||||
});
|
||||
|
||||
// Pre-filter to providers that have at least one visible model
|
||||
const providersWithVisibleModels = existingLlmProviders
|
||||
.map((provider) => ({
|
||||
provider,
|
||||
visibleModels: provider.model_configurations.filter((m) => m.is_visible),
|
||||
}))
|
||||
.filter(({ visibleModels }) => visibleModels.length > 0);
|
||||
|
||||
// Default model logic — use the global default from the API response
|
||||
const currentDefaultValue = defaultText
|
||||
? `${defaultText.provider_id}:${defaultText.model_name}`
|
||||
: undefined;
|
||||
|
||||
async function handleDefaultModelChange(compositeValue: string) {
|
||||
const separatorIndex = compositeValue.indexOf(":");
|
||||
const providerId = Number(compositeValue.slice(0, separatorIndex));
|
||||
const modelName = compositeValue.slice(separatorIndex + 1);
|
||||
|
||||
try {
|
||||
await setDefaultLlmModel(providerId, modelName);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
toast.success("Default model updated successfully!");
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
toast.error(`Failed to set default model: ${message}`);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={SvgCpu} title="LLM Models" separator />
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
{hasProviders ? (
|
||||
<Card>
|
||||
<HorizontalInput
|
||||
title="Default Model"
|
||||
description="This model will be used by Onyx by default in your chats."
|
||||
nonInteractive
|
||||
center
|
||||
>
|
||||
<InputSelect
|
||||
value={currentDefaultValue}
|
||||
onValueChange={handleDefaultModelChange}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select a default model" />
|
||||
<InputSelect.Content>
|
||||
{providersWithVisibleModels.map(
|
||||
({ provider, visibleModels }) => (
|
||||
<InputSelect.Group key={provider.id}>
|
||||
<InputSelect.Label>{provider.name}</InputSelect.Label>
|
||||
{visibleModels.map((model) => (
|
||||
<InputSelect.Item
|
||||
key={`${provider.id}:${model.name}`}
|
||||
value={`${provider.id}:${model.name}`}
|
||||
>
|
||||
{model.display_name || model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Group>
|
||||
)
|
||||
)}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</HorizontalInput>
|
||||
</Card>
|
||||
) : (
|
||||
<Message
|
||||
info
|
||||
large
|
||||
icon
|
||||
close={false}
|
||||
text="Set up an LLM provider to start chatting."
|
||||
className="w-full"
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* ── Available Providers (only when providers exist) ── */}
|
||||
{hasProviders && (
|
||||
<>
|
||||
<GeneralLayouts.Section
|
||||
gap={0.75}
|
||||
height="fit"
|
||||
alignItems="stretch"
|
||||
justifyContent="start"
|
||||
>
|
||||
<Content
|
||||
title="Available Providers"
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{sortedProviders.map((provider) => (
|
||||
<ExistingProviderCard
|
||||
key={provider.id}
|
||||
provider={provider}
|
||||
isDefault={defaultText?.provider_id === provider.id}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</GeneralLayouts.Section>
|
||||
|
||||
<Separator noPadding />
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* ── Add Provider (always visible) ── */}
|
||||
<GeneralLayouts.Section
|
||||
gap={0.75}
|
||||
height="fit"
|
||||
alignItems="stretch"
|
||||
justifyContent="start"
|
||||
>
|
||||
<Content
|
||||
title="Add Provider"
|
||||
description="Onyx supports both popular providers and self-hosted models."
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{wellKnownLLMProviders?.map((provider) => {
|
||||
const formFn = PROVIDER_MODAL_MAP[provider.name];
|
||||
if (!formFn) {
|
||||
toast.error(
|
||||
`No modal mapping for provider "${provider.name}".`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<NewProviderCard
|
||||
key={provider.name}
|
||||
provider={provider}
|
||||
isFirstProvider={isFirstProvider}
|
||||
formFn={formFn}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<NewCustomProviderCard isFirstProvider={isFirstProvider} />
|
||||
</div>
|
||||
</GeneralLayouts.Section>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
@@ -39,8 +39,8 @@ import DocumentSetCard from "@/sections/cards/DocumentSetCard";
|
||||
import {
|
||||
getLLMProviderOverrideForPersona,
|
||||
getDisplayName,
|
||||
} from "@/lib/llm/utils";
|
||||
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
|
||||
} from "@/lib/llmConfig/utils";
|
||||
import { useLLMProviders } from "@/hooks/useLLMProviders";
|
||||
import { Interactive } from "@opal/core";
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { LLMProviderFormProps } from "../interfaces";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -21,15 +21,19 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
export const ANTHROPIC_PROVIDER_NAME = "anthropic";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
|
||||
|
||||
export function AnthropicForm({
|
||||
export function AnthropicModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Anthropic"
|
||||
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -52,7 +56,6 @@ export function AnthropicForm({
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
api_base: existingLlmProvider?.api_base ?? undefined,
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
// Default to auto mode for new Anthropic providers
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -28,7 +28,7 @@ import Separator from "@/refresh-components/Separator";
|
||||
export const AZURE_PROVIDER_NAME = "azure";
|
||||
const AZURE_DISPLAY_NAME = "Microsoft Azure Cloud";
|
||||
|
||||
interface AzureFormValues extends BaseLLMFormValues {
|
||||
interface AzureModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
target_uri: string;
|
||||
api_base?: string;
|
||||
@@ -47,15 +47,19 @@ const buildTargetUri = (existingLlmProvider?: LLMProviderView): string => {
|
||||
return `${existingLlmProvider.api_base}/openai/deployments/${deploymentName}/chat/completions?api-version=${existingLlmProvider.api_version}`;
|
||||
};
|
||||
|
||||
export function AzureForm({
|
||||
export function AzureModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName={AZURE_DISPLAY_NAME}
|
||||
providerEndpoint={AZURE_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -70,7 +74,7 @@ export function AzureForm({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: AzureFormValues = {
|
||||
const initialValues: AzureModalValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -97,7 +101,7 @@ export function AzureForm({
|
||||
validateOnMount={true}
|
||||
onSubmit={async (values, { setSubmitting }) => {
|
||||
// Parse target_uri to extract api_base, api_version, and deployment_name
|
||||
let processedValues: AzureFormValues = { ...values };
|
||||
let processedValues: AzureModalValues = { ...values };
|
||||
|
||||
if (values.target_uri) {
|
||||
try {
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "../interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -27,7 +27,7 @@ import {
|
||||
} from "./formUtils";
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { DisplayModels } from "./components/DisplayModels";
|
||||
import { fetchBedrockModels } from "../utils";
|
||||
import { fetchBedrockModels } from "@/app/admin/configuration/llm/utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
@@ -65,7 +65,7 @@ const FIELD_AWS_ACCESS_KEY_ID = "custom_config.AWS_ACCESS_KEY_ID";
|
||||
const FIELD_AWS_SECRET_ACCESS_KEY = "custom_config.AWS_SECRET_ACCESS_KEY";
|
||||
const FIELD_AWS_BEARER_TOKEN_BEDROCK = "custom_config.AWS_BEARER_TOKEN_BEDROCK";
|
||||
|
||||
interface BedrockFormValues extends BaseLLMFormValues {
|
||||
interface BedrockModalValues extends BaseLLMFormValues {
|
||||
custom_config: {
|
||||
AWS_REGION_NAME: string;
|
||||
BEDROCK_AUTH_METHOD?: string;
|
||||
@@ -75,8 +75,8 @@ interface BedrockFormValues extends BaseLLMFormValues {
|
||||
};
|
||||
}
|
||||
|
||||
interface BedrockFormInternalsProps {
|
||||
formikProps: FormikProps<BedrockFormValues>;
|
||||
interface BedrockModalInternalsProps {
|
||||
formikProps: FormikProps<BedrockModalValues>;
|
||||
existingLlmProvider: LLMProviderView | undefined;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
@@ -87,7 +87,7 @@ interface BedrockFormInternalsProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
function BedrockFormInternals({
|
||||
function BedrockModalInternals({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
@@ -97,7 +97,7 @@ function BedrockFormInternals({
|
||||
testError,
|
||||
mutate,
|
||||
onClose,
|
||||
}: BedrockFormInternalsProps) {
|
||||
}: BedrockModalInternalsProps) {
|
||||
const authMethod = formikProps.values.custom_config?.BEDROCK_AUTH_METHOD;
|
||||
|
||||
// Clean up unused auth fields when tab changes
|
||||
@@ -258,9 +258,11 @@ function BedrockFormInternals({
|
||||
);
|
||||
}
|
||||
|
||||
export function BedrockForm({
|
||||
export function BedrockModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
@@ -268,6 +270,8 @@ export function BedrockForm({
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName={BEDROCK_DISPLAY_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -282,7 +286,7 @@ export function BedrockForm({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: BedrockFormValues = {
|
||||
const initialValues: BedrockModalValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -352,7 +356,7 @@ export function BedrockForm({
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<BedrockFormInternals
|
||||
<BedrockModalInternals
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
@@ -7,7 +7,7 @@
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
|
||||
import { CustomForm } from "./CustomForm";
|
||||
import { CustomModal } from "./CustomModal";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
// Mock SWR's mutate function and useSWR
|
||||
@@ -116,11 +116,10 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
name: "My Custom Provider",
|
||||
provider: "openai",
|
||||
api_key: "test-key",
|
||||
default_model_name: "gpt-4",
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm />);
|
||||
render(<CustomModal />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "My Custom Provider",
|
||||
@@ -177,7 +176,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ detail: "Invalid API key" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm />);
|
||||
render(<CustomModal />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "Bad Provider",
|
||||
@@ -224,13 +223,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
api_key: "old-key",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
default_model_name: "claude-3-opus",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "claude-3-opus",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
custom_config: {},
|
||||
@@ -239,9 +238,6 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
groups: [],
|
||||
personas: [],
|
||||
deployment_name: null,
|
||||
is_default_provider: false,
|
||||
default_vision_model: null,
|
||||
is_default_vision_provider: null,
|
||||
};
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
@@ -256,7 +252,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ ...existingProvider, api_key: "new-key" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm existingLlmProvider={existingProvider} />);
|
||||
render(<CustomModal existingLlmProvider={existingProvider} />);
|
||||
|
||||
// For existing provider, click "Edit" button to open modal
|
||||
const editButton = screen.getByRole("button", { name: /edit/i });
|
||||
@@ -307,13 +303,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
api_key: "old-key",
|
||||
api_base: "https://example-openai-compatible.local/v1",
|
||||
api_version: "",
|
||||
default_model_name: "gpt-oss-20b-bw-failover",
|
||||
model_configurations: [
|
||||
{
|
||||
name: "gpt-oss-20b-bw-failover",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
custom_config: {},
|
||||
@@ -322,9 +318,6 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
groups: [],
|
||||
personas: [],
|
||||
deployment_name: null,
|
||||
is_default_provider: false,
|
||||
default_vision_model: null,
|
||||
is_default_vision_provider: null,
|
||||
};
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
@@ -343,19 +336,21 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
name: "gpt-oss-20b-bw-failover",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
{
|
||||
name: "nemotron",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm existingLlmProvider={existingProvider} />);
|
||||
render(<CustomModal existingLlmProvider={existingProvider} />);
|
||||
|
||||
const editButton = screen.getByRole("button", { name: /edit/i });
|
||||
await user.click(editButton);
|
||||
@@ -423,7 +418,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm shouldMarkAsDefault={true} />);
|
||||
render(<CustomModal shouldMarkAsDefault={true} />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "New Default Provider",
|
||||
@@ -463,7 +458,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ detail: "Database error" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm />);
|
||||
render(<CustomModal />);
|
||||
|
||||
await openModalAndFillBasicFields(user, {
|
||||
name: "Test Provider",
|
||||
@@ -499,7 +494,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
json: async () => ({ id: 1, name: "Provider with Custom Config" }),
|
||||
} as Response);
|
||||
|
||||
render(<CustomForm />);
|
||||
render(<CustomModal />);
|
||||
|
||||
// Open modal
|
||||
const openButton = screen.getByRole("button", {
|
||||
@@ -7,7 +7,7 @@ import {
|
||||
Formik,
|
||||
ErrorMessage,
|
||||
} from "formik";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
|
||||
import { LLMProviderFormProps, LLMProviderView } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
|
||||
import { DisplayNameField } from "./components/DisplayNameField";
|
||||
@@ -21,7 +21,7 @@ import {
|
||||
} from "./formUtils";
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { ModelConfigurationField } from "../ModelConfigurationField";
|
||||
import { ModelConfigurationField } from "@/app/admin/configuration/llm/ModelConfigurationField";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
@@ -38,9 +38,11 @@ function customConfigProcessing(customConfigsList: [string, string][]) {
|
||||
return customConfig;
|
||||
}
|
||||
|
||||
export function CustomForm({
|
||||
export function CustomModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
@@ -48,6 +50,8 @@ export function CustomForm({
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
buttonMode={!existingLlmProvider}
|
||||
buttonText="Add Custom LLM Provider"
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -68,7 +72,15 @@ export function CustomForm({
|
||||
...modelConfiguration,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
})
|
||||
) ?? [{ name: "", is_visible: true, max_input_tokens: null }],
|
||||
) ?? [
|
||||
{
|
||||
name: "",
|
||||
is_visible: true,
|
||||
max_input_tokens: null,
|
||||
supports_image_input: false,
|
||||
supports_reasoning: false,
|
||||
},
|
||||
],
|
||||
custom_config_list: existingLlmProvider?.custom_config
|
||||
? Object.entries(existingLlmProvider.custom_config)
|
||||
: [],
|
||||
@@ -112,7 +124,8 @@ export function CustomForm({
|
||||
name: mc.name,
|
||||
is_visible: mc.is_visible,
|
||||
max_input_tokens: mc.max_input_tokens ?? null,
|
||||
supports_image_input: null,
|
||||
supports_image_input: mc.supports_image_input ?? false,
|
||||
supports_reasoning: mc.supports_reasoning ?? false,
|
||||
}))
|
||||
.filter(
|
||||
(mc) => mc.name === values.default_model_name || mc.is_visible
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
LLMProviderFormProps,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
} from "../interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -24,20 +24,20 @@ import {
|
||||
import { AdvancedOptions } from "./components/AdvancedOptions";
|
||||
import { DisplayModels } from "./components/DisplayModels";
|
||||
import { useEffect, useState } from "react";
|
||||
import { fetchOllamaModels } from "../utils";
|
||||
import { fetchOllamaModels } from "@/app/admin/configuration/llm/utils";
|
||||
|
||||
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
|
||||
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
|
||||
|
||||
interface OllamaFormValues extends BaseLLMFormValues {
|
||||
interface OllamaModalValues extends BaseLLMFormValues {
|
||||
api_base: string;
|
||||
custom_config: {
|
||||
OLLAMA_API_KEY?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface OllamaFormContentProps {
|
||||
formikProps: FormikProps<OllamaFormValues>;
|
||||
interface OllamaModalContentProps {
|
||||
formikProps: FormikProps<OllamaModalValues>;
|
||||
existingLlmProvider?: LLMProviderView;
|
||||
fetchedModels: ModelConfiguration[];
|
||||
setFetchedModels: (models: ModelConfiguration[]) => void;
|
||||
@@ -48,7 +48,7 @@ interface OllamaFormContentProps {
|
||||
isFormValid: boolean;
|
||||
}
|
||||
|
||||
function OllamaFormContent({
|
||||
function OllamaModalContent({
|
||||
formikProps,
|
||||
existingLlmProvider,
|
||||
fetchedModels,
|
||||
@@ -58,7 +58,7 @@ function OllamaFormContent({
|
||||
mutate,
|
||||
onClose,
|
||||
isFormValid,
|
||||
}: OllamaFormContentProps) {
|
||||
}: OllamaModalContentProps) {
|
||||
const [isLoadingModels, setIsLoadingModels] = useState(true);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -131,9 +131,11 @@ function OllamaFormContent({
|
||||
);
|
||||
}
|
||||
|
||||
export function OllamaForm({
|
||||
export function OllamaModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
@@ -141,6 +143,8 @@ export function OllamaForm({
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="Ollama"
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -155,7 +159,7 @@ export function OllamaForm({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: OllamaFormValues = {
|
||||
const initialValues: OllamaModalValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -212,7 +216,7 @@ export function OllamaForm({
|
||||
}}
|
||||
>
|
||||
{(formikProps) => (
|
||||
<OllamaFormContent
|
||||
<OllamaModalContent
|
||||
formikProps={formikProps}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
fetchedModels={fetchedModels}
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Form, Formik } from "formik";
|
||||
|
||||
import { LLMProviderFormProps } from "../interfaces";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
|
||||
import { DisplayNameField } from "./components/DisplayNameField";
|
||||
@@ -19,15 +19,19 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
export const OPENAI_PROVIDER_NAME = "openai";
|
||||
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
|
||||
|
||||
export function OpenAIForm({
|
||||
export function OpenAIModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName="OpenAI"
|
||||
providerEndpoint={OPENAI_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -49,7 +53,6 @@ export function OpenAIForm({
|
||||
),
|
||||
api_key: existingLlmProvider?.api_key ?? "",
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
DEFAULT_DEFAULT_MODEL_NAME,
|
||||
// Default to auto mode for new OpenAI providers
|
||||
@@ -5,7 +5,7 @@ import {
|
||||
LLMProviderFormProps,
|
||||
ModelConfiguration,
|
||||
OpenRouterModelResponse,
|
||||
} from "../interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -32,7 +32,7 @@ const OPENROUTER_DISPLAY_NAME = "OpenRouter";
|
||||
const DEFAULT_API_BASE = "https://openrouter.ai/api/v1";
|
||||
const OPENROUTER_MODELS_API_URL = "/api/admin/llm/openrouter/available-models";
|
||||
|
||||
interface OpenRouterFormValues extends BaseLLMFormValues {
|
||||
interface OpenRouterModalValues extends BaseLLMFormValues {
|
||||
api_key: string;
|
||||
api_base: string;
|
||||
}
|
||||
@@ -80,6 +80,7 @@ async function fetchOpenRouterModels(params: {
|
||||
is_visible: true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
supports_reasoning: false,
|
||||
}));
|
||||
|
||||
return { models };
|
||||
@@ -90,9 +91,11 @@ async function fetchOpenRouterModels(params: {
|
||||
}
|
||||
}
|
||||
|
||||
export function OpenRouterForm({
|
||||
export function OpenRouterModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
|
||||
|
||||
@@ -101,6 +104,8 @@ export function OpenRouterForm({
|
||||
providerName={OPENROUTER_DISPLAY_NAME}
|
||||
providerEndpoint={OPENROUTER_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -115,7 +120,7 @@ export function OpenRouterForm({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: OpenRouterFormValues = {
|
||||
const initialValues: OpenRouterModalValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import { TextFormField, FileUploadFormField } from "@/components/Field";
|
||||
import { LLMProviderFormProps } from "../interfaces";
|
||||
import { LLMProviderFormProps } from "@/interfaces/llm";
|
||||
import * as Yup from "yup";
|
||||
import {
|
||||
ProviderFormEntrypointWrapper,
|
||||
@@ -25,22 +25,26 @@ const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
|
||||
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
|
||||
const VERTEXAI_DEFAULT_LOCATION = "global";
|
||||
|
||||
interface VertexAIFormValues extends BaseLLMFormValues {
|
||||
interface VertexAIModalValues extends BaseLLMFormValues {
|
||||
custom_config: {
|
||||
vertex_credentials: string;
|
||||
vertex_location: string;
|
||||
};
|
||||
}
|
||||
|
||||
export function VertexAIForm({
|
||||
export function VertexAIModal({
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: LLMProviderFormProps) {
|
||||
return (
|
||||
<ProviderFormEntrypointWrapper
|
||||
providerName={VERTEXAI_DISPLAY_NAME}
|
||||
providerEndpoint={VERTEXAI_PROVIDER_NAME}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
open={open}
|
||||
onOpenChange={onOpenChange}
|
||||
>
|
||||
{({
|
||||
onClose,
|
||||
@@ -55,13 +59,12 @@ export function VertexAIForm({
|
||||
existingLlmProvider,
|
||||
wellKnownLLMProvider
|
||||
);
|
||||
const initialValues: VertexAIFormValues = {
|
||||
const initialValues: VertexAIModalValues = {
|
||||
...buildDefaultInitialValues(
|
||||
existingLlmProvider,
|
||||
modelConfigurations
|
||||
),
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
wellKnownLLMProvider?.recommended_default_model?.name ??
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
// Default to auto mode for new Vertex AI providers
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ModelConfiguration, SimpleKnownModel } from "../../interfaces";
|
||||
import { ModelConfiguration, SimpleKnownModel } from "@/interfaces/llm";
|
||||
import { FormikProps } from "formik";
|
||||
import { BaseLLMFormValues } from "../formUtils";
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState, useEffect } from "react";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import { ModelConfiguration } from "../../interfaces";
|
||||
import { ModelConfiguration } from "@/interfaces/llm";
|
||||
|
||||
interface FetchModelsButtonProps {
|
||||
onFetch: () => Promise<{ models: ModelConfiguration[]; error?: string }>;
|
||||
@@ -2,8 +2,9 @@ import { LoadingAnimation } from "@/components/Loading";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgTrash } from "@opal/icons";
|
||||
import { LLMProviderView } from "../../interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
|
||||
import { LLMProviderView } from "@/interfaces/llm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
import { deleteLlmProvider } from "@/lib/llmConfig/svc";
|
||||
|
||||
interface FormActionButtonsProps {
|
||||
isTesting: boolean;
|
||||
@@ -25,41 +26,14 @@ export function FormActionButtons({
|
||||
const handleDelete = async () => {
|
||||
if (!existingLlmProvider) return;
|
||||
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
alert(`Failed to delete provider: ${errorMsg}`);
|
||||
return;
|
||||
try {
|
||||
await deleteLlmProvider(existingLlmProvider.id);
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
} catch (e) {
|
||||
const message = e instanceof Error ? e.message : "Unknown error";
|
||||
alert(`Failed to delete provider: ${message}`);
|
||||
}
|
||||
|
||||
// If the deleted provider was the default, set the first remaining provider as default
|
||||
if (existingLlmProvider.is_default_provider) {
|
||||
const remainingProvidersResponse = await fetch(LLM_PROVIDERS_ADMIN_URL);
|
||||
if (remainingProvidersResponse.ok) {
|
||||
const remainingProviders = await remainingProvidersResponse.json();
|
||||
|
||||
if (remainingProviders.length > 0) {
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!setDefaultResponse.ok) {
|
||||
console.error("Failed to set new default provider");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
@@ -6,7 +6,7 @@ import { toast } from "@/hooks/useToast";
|
||||
import {
|
||||
LLMProviderView,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "../../interfaces";
|
||||
} from "@/interfaces/llm";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
@@ -14,7 +14,7 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
|
||||
|
||||
export interface ProviderFormContext {
|
||||
onClose: () => void;
|
||||
@@ -35,6 +35,10 @@ interface ProviderFormEntrypointWrapperProps {
|
||||
buttonMode?: boolean;
|
||||
/** Custom button text for buttonMode (defaults to "Add {providerName}") */
|
||||
buttonText?: string;
|
||||
/** Controlled open state — when defined, the wrapper renders only a modal (no card/button UI) */
|
||||
open?: boolean;
|
||||
/** Callback when controlled modal requests close */
|
||||
onOpenChange?: (open: boolean) => void;
|
||||
}
|
||||
|
||||
export function ProviderFormEntrypointWrapper({
|
||||
@@ -44,8 +48,11 @@ export function ProviderFormEntrypointWrapper({
|
||||
existingLlmProvider,
|
||||
buttonMode,
|
||||
buttonText,
|
||||
open,
|
||||
onOpenChange,
|
||||
}: ProviderFormEntrypointWrapperProps) {
|
||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||
const isControlled = open !== undefined;
|
||||
|
||||
// Shared hooks
|
||||
const { mutate } = useSWRConfig();
|
||||
@@ -54,15 +61,25 @@ export function ProviderFormEntrypointWrapper({
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
// Suppress SWR when controlled + closed to avoid unnecessary API calls
|
||||
const swrKey =
|
||||
providerEndpoint && !(isControlled && !open)
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null;
|
||||
|
||||
// Fetch model configurations for this provider
|
||||
const { data: wellKnownLLMProvider } = useSWR<WellKnownLLMProviderDescriptor>(
|
||||
providerEndpoint
|
||||
? `/api/admin/llm/built-in/options/${providerEndpoint}`
|
||||
: null,
|
||||
swrKey,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const onClose = () => setFormIsVisible(false);
|
||||
const onClose = () => {
|
||||
if (isControlled) {
|
||||
onOpenChange?.(false);
|
||||
} else {
|
||||
setFormIsVisible(false);
|
||||
}
|
||||
};
|
||||
|
||||
async function handleSetAsDefault(): Promise<void> {
|
||||
if (!existingLlmProvider) return;
|
||||
@@ -93,6 +110,28 @@ export function ProviderFormEntrypointWrapper({
|
||||
wellKnownLLMProvider,
|
||||
};
|
||||
|
||||
// Controlled mode: render nothing when closed, render only modal when open
|
||||
if (isControlled) {
|
||||
if (!open) return null;
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={`${existingLlmProvider ? "Configure" : "Setup"} ${
|
||||
existingLlmProvider?.name
|
||||
? `"${existingLlmProvider.name}"`
|
||||
: providerName
|
||||
}`}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>{children(context)}</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
// Button mode: simple button that opens a modal
|
||||
if (buttonMode && !existingLlmProvider) {
|
||||
return (
|
||||
@@ -135,24 +174,18 @@ export function ProviderFormEntrypointWrapper({
|
||||
<Text as="p" secondaryBody text03 className="italic">
|
||||
({providerName})
|
||||
</Text>
|
||||
{!existingLlmProvider.is_default_provider && (
|
||||
<Text
|
||||
as="p"
|
||||
className={cn("text-action-link-05", "cursor-pointer")}
|
||||
onClick={handleSetAsDefault}
|
||||
>
|
||||
Set as default
|
||||
</Text>
|
||||
)}
|
||||
<Text
|
||||
as="p"
|
||||
className={cn("text-action-link-05", "cursor-pointer")}
|
||||
onClick={handleSetAsDefault}
|
||||
>
|
||||
Set as default
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{existingLlmProvider && (
|
||||
<div className="my-auto ml-3">
|
||||
{existingLlmProvider.is_default_provider ? (
|
||||
<Badge variant="agent">Default</Badge>
|
||||
) : (
|
||||
<Badge variant="success">Enabled</Badge>
|
||||
)}
|
||||
<Badge variant="success">Enabled</Badge>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -2,8 +2,11 @@ import {
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "../interfaces";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "../constants";
|
||||
} from "@/interfaces/llm";
|
||||
import {
|
||||
LLM_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "@/lib/llmConfig/constants";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
@@ -15,10 +18,7 @@ export const buildDefaultInitialValues = (
|
||||
existingLlmProvider?: LLMProviderView,
|
||||
modelConfigurations?: ModelConfiguration[]
|
||||
) => {
|
||||
const defaultModelName =
|
||||
existingLlmProvider?.default_model_name ??
|
||||
modelConfigurations?.[0]?.name ??
|
||||
"";
|
||||
const defaultModelName = modelConfigurations?.[0]?.name ?? "";
|
||||
|
||||
// Auto mode must be explicitly enabled by the user
|
||||
// Default to false for new providers, preserve existing value when editing
|
||||
@@ -119,6 +119,7 @@ export const filterModelConfigurations = (
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
)
|
||||
@@ -141,6 +142,7 @@ export const getAutoModeModelConfigurations = (
|
||||
is_visible: modelConfiguration.is_visible,
|
||||
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
|
||||
supports_image_input: modelConfiguration.supports_image_input,
|
||||
supports_reasoning: modelConfiguration.supports_reasoning,
|
||||
display_name: modelConfiguration.display_name,
|
||||
})
|
||||
);
|
||||
@@ -223,6 +225,8 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
model: finalDefaultModelName,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
});
|
||||
setIsTesting(false);
|
||||
@@ -247,6 +251,7 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
body: JSON.stringify({
|
||||
provider: providerName,
|
||||
...finalValues,
|
||||
id: existingLlmProvider?.id,
|
||||
}),
|
||||
}
|
||||
);
|
||||
@@ -262,12 +267,16 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
|
||||
if (shouldMarkAsDefault) {
|
||||
const newLlmProvider = (await response.json()) as LLMProviderView;
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider_id: newLlmProvider.id,
|
||||
model_name: finalDefaultModelName,
|
||||
}),
|
||||
});
|
||||
if (!setDefaultResponse.ok) {
|
||||
const errorMsg = (await setDefaultResponse.json()).detail;
|
||||
toast.error(`Failed to set provider as default: ${errorMsg}`);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user