Compare commits

...

60 Commits

Author SHA1 Message Date
Dane Urban
908d360011 . 2026-02-06 17:52:23 -08:00
Dane Urban
30578bdf9a n 2026-02-06 17:38:36 -08:00
Dane Urban
aebde89432 nits 2026-02-06 16:25:00 -08:00
Dane Urban
4a4b4bb378 t 2026-02-06 13:39:05 -08:00
Dane Urban
a8d231976a nit 2026-02-06 09:56:16 -08:00
Dane Urban
9c8ae5bb4b nit 2026-02-05 17:07:24 -08:00
Dane Urban
0fc1fa3d36 nits 2026-02-05 10:28:59 -08:00
Dane Urban
94633698c3 nit 2026-02-03 00:42:20 -08:00
Dane Urban
6ae15589cd nits 2026-02-02 18:56:22 -08:00
Dane Urban
c24a8bb228 Add change 2026-02-02 18:55:38 -08:00
Dane Urban
01945abd86 fix test 2026-02-02 16:49:31 -08:00
Dane Urban
658632195f nit 2026-02-02 16:47:21 -08:00
Dane Urban
ec6fd01ba4 Merge branch 'llm_provider_refactor_1' into llm_provider_refactor_2 2026-02-02 15:02:12 -08:00
Dane Urban
148e6fb97d nit 2026-02-02 15:01:57 -08:00
Dane Urban
6598c1a48d nit 2026-02-02 14:59:42 -08:00
Dane Urban
497ce43bd8 Fix some tests 2026-02-02 13:36:42 -08:00
Dane Urban
8634cb0446 Merge branch 'llm_provider_refactor_1' into llm_provider_refactor_2 2026-02-02 13:28:29 -08:00
Dane Urban
8d56fd3dc6 . 2026-02-02 13:27:08 -08:00
Dane Urban
a7579a99d0 Resolve merge conflicts 2026-02-02 12:01:44 -08:00
Dane Urban
3533c10da4 n 2026-02-02 11:48:28 -08:00
Dane Urban
7b0414bf0d fix migration 2026-02-02 11:48:08 -08:00
Dane Urban
b500ea537a nits 2026-02-02 11:46:52 -08:00
Dane Urban
abd6d55add Merge branch 'flow_mapping_table' into llm_provider_refactor_1 2026-02-02 11:44:27 -08:00
Dane Urban
f15b6b8034 Merge branch 'main' into llm_provider_refactor_1 2026-02-02 11:44:17 -08:00
Dane Urban
fb40485f25 Update this 2026-02-02 11:43:58 -08:00
Dane Urban
22e85f1f28 Merge branch 'main' into flow_mapping_table 2026-02-02 11:43:24 -08:00
Dane Urban
2ef7c3e6f3 rename 2026-02-02 11:40:21 -08:00
Dane Urban
92a471ed2b . 2026-02-02 11:35:09 -08:00
Dane Urban
d1b7e529a4 nit 2026-02-02 11:32:33 -08:00
Dane Urban
95c3579264 nits 2026-02-02 11:19:51 -08:00
Dane Urban
8802e5cad3 nit 2026-02-02 11:02:58 -08:00
Dane Urban
a41b4bbc82 fix tests 2026-02-01 22:59:15 -08:00
Dane Urban
c026c077b5 nit 2026-02-01 22:53:38 -08:00
Dane Urban
3eee539a86 Merge branch 'llm_provider_refactor_1' into llm_provider_refactor_2 2026-02-01 22:13:54 -08:00
Dane Urban
143e7a0d72 nits 2026-02-01 22:13:21 -08:00
Dane Urban
4572358038 nits 2026-02-01 22:10:37 -08:00
Dane Urban
1753f94c11 start fixes 2026-02-01 21:51:02 -08:00
Dane Urban
120ddf2ef6 Merge branch 'llm_provider_refactor_1' into llm_provider_refactor_2 2026-02-01 21:42:40 -08:00
Dane Urban
2cce5bc58f Merge branch 'main' into flow_mapping_table 2026-02-01 21:38:54 -08:00
Dane Urban
383a6001d2 nit 2026-02-01 21:37:35 -08:00
Dane Urban
3a6f45bfca Merge branch 'main' into llm_provider_refactor_1 2026-02-01 19:36:43 -08:00
Dane Urban
e06b5ef202 Merge branch 'flow_mapping_table' into llm_provider_refactor_1 2026-02-01 15:23:59 -08:00
Dane Urban
c13ce816fa fix revision id 2026-02-01 13:55:01 -08:00
Dane Urban
39f3e872ec Merge branch 'main' into flow_mapping_table 2026-02-01 13:53:53 -08:00
Dane Urban
b033c00217 . 2026-02-01 13:52:58 -08:00
Dane Urban
6d47c5f21a nit 2026-02-01 13:51:54 -08:00
Dane Urban
0645540e24 . 2026-01-31 23:44:17 -08:00
Dane Urban
a2c0fc4df0 . 2026-01-31 19:23:46 -08:00
Dane Urban
7dccc88b35 . 2026-01-31 18:24:42 -08:00
Dane Urban
ac617a51ce nits 2026-01-31 17:30:49 -08:00
Dane Urban
339a111a8f . 2026-01-30 18:19:03 -08:00
Dane Urban
09b7e6fc9b fix revision id 2026-01-30 17:39:02 -08:00
Dane Urban
135238014f Merge branch 'main' into flow_mapping_table 2026-01-30 17:38:20 -08:00
Dane Urban
303e37bf53 migrate 2026-01-30 17:38:15 -08:00
Dane Urban
6a888e9900 nit 2026-01-30 17:01:22 -08:00
Dane Urban
e90a7767c6 nit 2026-01-30 15:35:31 -08:00
Dane Urban
1ded3af63c nit 2026-01-30 14:22:27 -08:00
Dane Urban
c53546c000 nit 2026-01-30 13:03:05 -08:00
Dane Urban
9afa12edda nit 2026-01-30 13:02:48 -08:00
Dane Urban
32046de962 nit 2026-01-30 13:01:36 -08:00
71 changed files with 2901 additions and 1609 deletions

View File

@@ -0,0 +1,58 @@
"""LLMProvider deprecated fields are nullable
Revision ID: 001984c88745
Revises: 01f8e6d95a33
Create Date: 2026-02-01 22:24:34.171100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "001984c88745"
down_revision = "01f8e6d95a33"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Make default_model_name nullable (was NOT NULL)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=True,
)
# Remove server_default from is_default_vision_provider (was server_default=false())
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=None,
)
# is_default_provider and default_vision_model are already nullable with no server_default
def downgrade() -> None:
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
op.execute(
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=False,
)
# Restore server_default for is_default_vision_provider
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=sa.false(),
)

View File

@@ -0,0 +1,112 @@
"""Populate flow mapping data
Revision ID: 01f8e6d95a33
Revises: f220515df7b4
Create Date: 2026-01-31 17:37:10.485558
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "01f8e6d95a33"
down_revision = "f220515df7b4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add each model config to the conversation flow, setting the global default if it exists
# Exclude models that are part of ImageGenerationConfig
op.execute(
"""
INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id)
SELECT
'chat' AS llm_model_flow_type,
COALESCE(
(lp.is_default_provider IS TRUE AND lp.default_model_name = mc.name),
FALSE
) AS is_default,
mc.id AS model_configuration_id
FROM model_configuration mc
LEFT JOIN llm_provider lp
ON lp.id = mc.llm_provider_id
WHERE NOT EXISTS (
SELECT 1 FROM image_generation_config igc
WHERE igc.model_configuration_id = mc.id
);
"""
)
# Add models with supports_image_input to the vision flow
op.execute(
"""
INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id)
SELECT
'vision' AS llm_model_flow_type,
COALESCE(
(lp.is_default_vision_provider IS TRUE AND lp.default_vision_model = mc.name),
FALSE
) AS is_default,
mc.id AS model_configuration_id
FROM model_configuration mc
LEFT JOIN llm_provider lp
ON lp.id = mc.llm_provider_id
WHERE mc.supports_image_input IS TRUE;
"""
)
def downgrade() -> None:
# Populate vision defaults from model_flow
op.execute(
"""
UPDATE llm_provider AS lp
SET
is_default_vision_provider = TRUE,
default_vision_model = mc.name
FROM llm_model_flow mf
JOIN model_configuration mc ON mc.id = mf.model_configuration_id
WHERE mf.llm_model_flow_type = 'vision'
AND mf.is_default = TRUE
AND mc.llm_provider_id = lp.id;
"""
)
# Populate conversation defaults from model_flow
op.execute(
"""
UPDATE llm_provider AS lp
SET
is_default_provider = TRUE,
default_model_name = mc.name
FROM llm_model_flow mf
JOIN model_configuration mc ON mc.id = mf.model_configuration_id
WHERE mf.llm_model_flow_type = 'chat'
AND mf.is_default = TRUE
AND mc.llm_provider_id = lp.id;
"""
)
# For providers that have conversation flow mappings but aren't the default,
# we still need a default_model_name (it was NOT NULL originally)
# Pick the first visible model or any model for that provider
op.execute(
"""
UPDATE llm_provider AS lp
SET default_model_name = (
SELECT mc.name
FROM model_configuration mc
JOIN llm_model_flow mf ON mf.model_configuration_id = mc.id
WHERE mc.llm_provider_id = lp.id
AND mf.llm_model_flow_type = 'chat'
ORDER BY mc.is_visible DESC, mc.id ASC
LIMIT 1
)
WHERE lp.default_model_name IS NULL;
"""
)
# Delete all model_flow entries (reverse the inserts from upgrade)
op.execute("DELETE FROM llm_model_flow;")

View File

@@ -0,0 +1,57 @@
"""Add flow mapping table
Revision ID: f220515df7b4
Revises: cbc03e08d0f3
Create Date: 2026-01-30 12:21:24.955922
"""
from onyx.db.enums import LLMModelFlowType
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f220515df7b4"
down_revision = "9d1543a37106"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"llm_model_flow",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"llm_model_flow_type",
sa.Enum(LLMModelFlowType, name="llmmodelflowtype", native_enum=False),
nullable=False,
),
sa.Column(
"is_default", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("model_configuration_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE"
),
sa.UniqueConstraint(
"llm_model_flow_type",
"model_configuration_id",
name="uq_model_config_per_llm_model_flow_type",
),
)
# Partial unique index so that there is at most one default for each flow type
op.create_index(
"ix_one_default_per_llm_model_flow",
"llm_model_flow",
["llm_model_flow_type"],
unique=True,
postgresql_where=sa.text("is_default IS TRUE"),
)
def downgrade() -> None:
# Drop the llm_model_flow table (index is dropped automatically with table)
op.drop_table("llm_model_flow")

View File

@@ -123,9 +123,14 @@ def _seed_llms(
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
if len(seeded_providers[0].model_configurations) > 0:
default_model = seeded_providers[0].model_configurations[0].name
update_default_provider(
provider_id=seeded_providers[0].id,
model_name=default_model,
db_session=db_session,
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:

View File

@@ -300,12 +300,12 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
nonlocal has_set_default_provider
try:
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, default_model, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -323,14 +323,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
_upsert(openai_provider, default_model_name)
# Create default image generation config using the OpenAI API key
try:
@@ -359,14 +358,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
_upsert(anthropic_provider, default_model_name)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -391,14 +389,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
_upsert(vertexai_provider, default_model_name)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -430,12 +427,11 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
_upsert(openrouter_provider, default_model_name)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -285,3 +285,9 @@ class HierarchyNodeType(str, PyEnum):
# Slack
CHANNEL = "channel"
class LLMModelFlowType(str, PyEnum):
CHAT = "chat"
VISION = "vision"
EMBEDDINGS = "embeddings"

View File

@@ -1,12 +1,15 @@
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from onyx.db.models import DocumentSet
from onyx.db.models import ImageGenerationConfig
from onyx.db.models import LLMModelFlow
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
@@ -233,9 +236,6 @@ def upsert_llm_provider(
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -244,9 +244,29 @@ def upsert_llm_provider(
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Delete existing model configurations
models_to_exist = {
mc.name for mc in llm_provider_upsert_request.model_configurations
}
removed_model_configuration_ids = {
mc.id
for mc in existing_llm_provider.model_configurations
if mc.name not in models_to_exist
}
updated_model_configuration_names = {
mc.name: mc.id
for mc in existing_llm_provider.model_configurations
if mc.name in models_to_exist
}
new_model_names = models_to_exist - {
mc.name for mc in existing_llm_provider.model_configurations
}
# Delete removed models
db_session.query(ModelConfiguration).filter(
ModelConfiguration.llm_provider_id == existing_llm_provider.id
ModelConfiguration.id.in_(removed_model_configuration_ids)
).delete(synchronize_session="fetch")
db_session.flush()
@@ -263,18 +283,31 @@ def upsert_llm_provider(
model_provider=llm_provider_upsert_request.provider,
)
db_session.execute(
insert(ModelConfiguration)
.values(
supported_flows = [LLMModelFlowType.CHAT]
if model_configuration.supports_image_input:
supported_flows.append(LLMModelFlowType.VISION)
if model_configuration.name in new_model_names:
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=existing_llm_provider.id,
name=model_configuration.name,
model_name=model_configuration.name,
supported_flows=supported_flows,
is_visible=model_configuration.is_visible,
max_input_tokens=max_input_tokens,
display_name=model_configuration.display_name,
)
elif model_configuration.name in updated_model_configuration_names:
update_model_configuration__no_commit(
db_session=db_session,
model_configuration_id=updated_model_configuration_names[
model_configuration.name
],
supported_flows=supported_flows,
is_visible=model_configuration.is_visible,
max_input_tokens=max_input_tokens,
supports_image_input=model_configuration.supports_image_input,
display_name=model_configuration.display_name,
)
.on_conflict_do_nothing()
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
@@ -331,17 +364,18 @@ def sync_model_configurations(
model_name = model["name"]
if model_name not in existing_names:
# Insert new model with is_visible=False (user must explicitly enable)
db_session.execute(
insert(ModelConfiguration)
.values(
llm_provider_id=provider.id,
name=model_name,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
supports_image_input=model.get("supports_image_input", False),
display_name=model.get("display_name"),
)
.on_conflict_do_nothing()
supported_flows = [LLMModelFlowType.CHAT]
if model.get("supports_image_input", False):
supported_flows.append(LLMModelFlowType.VISION)
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
model_name=model_name,
supported_flows=supported_flows,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
display_name=model.get("display_name"),
)
new_count += 1
@@ -371,8 +405,26 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
)
def fetch_existing_models(
db_session: Session,
flow_types: list[LLMModelFlowType],
) -> list[ModelConfiguration]:
models = (
select(ModelConfiguration)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.options(
selectinload(ModelConfiguration.llm_provider),
selectinload(ModelConfiguration.llm_model_flows),
)
)
return list(db_session.scalars(models).all())
def fetch_existing_llm_providers(
db_session: Session,
flow_types: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -380,23 +432,37 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_types: List of flow types to filter by
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
stmt = select(LLMProviderModel).options(
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
if exclude_image_generation_providers:
# Get LLM provider IDs used by ImageGenerationConfig
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = stmt.where(LLMProviderModel.id.not_in(image_gen_provider_ids))
providers = list(db_session.scalars(stmt).all())
if only_public:
return [provider for provider in providers if provider.is_public]
@@ -429,26 +495,29 @@ def fetch_embedding_provider(
)
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.is_default_provider == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
)
if not provider_model:
return None
return LLMProviderView.from_model(provider_model)
def fetch_default_llm_model(db_session: Session) -> ModelConfiguration | None:
return fetch_default_model(db_session, LLMModelFlowType.CHAT)
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.is_default_vision_provider == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None:
return fetch_default_model(db_session, LLMModelFlowType.VISION)
def fetch_default_model(
db_session: Session,
flow_type: LLMModelFlowType,
) -> ModelConfiguration | None:
model_config = db_session.scalar(
select(ModelConfiguration)
.join(LLMModelFlow)
.where(
ModelConfiguration.is_visible == True, # noqa: E712
LLMModelFlow.llm_model_flow_type == flow_type,
LLMModelFlow.is_default == True, # noqa: E712
)
)
if not provider_model:
return None
return LLMProviderView.from_model(provider_model)
return model_config
def fetch_llm_provider_view(
@@ -526,61 +595,40 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(provider_id: int, db_session: Session) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
_update_default_model(
db_session,
provider_id,
model_name,
LLMModelFlowType.CHAT,
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
existing_default = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
)
)
if existing_default:
existing_default.is_default_provider = None
# required to ensure that the below does not cause a unique constraint violation
db_session.flush()
new_default.is_default_provider = True
db_session.commit()
def update_default_vision_provider(
provider_id: int, vision_model: str | None, db_session: Session
provider_id: int, vision_model: str, db_session: Session
) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
# Validate that the specified vision model supports image input
model_to_validate = vision_model or new_default.default_model_name
if model_to_validate:
if not model_supports_image_input(model_to_validate, new_default.provider):
raise ValueError(
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
)
else:
raise ValueError(
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
)
existing_default = db_session.scalar(
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
LLMProviderModel.id == provider_id,
)
)
if existing_default:
existing_default.is_default_vision_provider = None
# required to ensure that the below does not cause a unique constraint violation
db_session.flush()
new_default.is_default_vision_provider = True
new_default.default_vision_model = vision_model
db_session.commit()
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
if not model_supports_image_input(vision_model, provider.provider):
raise ValueError(
f"Model '{vision_model}' for provider '{provider.provider} does not support image input"
)
_update_default_model(
db_session=db_session,
provider_id=provider_id,
model=vision_model,
flow_type=LLMModelFlowType.VISION,
)
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
@@ -670,11 +718,162 @@ def sync_auto_mode_models(
db_session.add(new_model)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes
def create_new_flow_mapping__no_commit(
db_session: Session,
model_configuration_id: int,
flow_type: LLMModelFlowType,
) -> LLMModelFlow:
result = db_session.execute(
insert(LLMModelFlow)
.values(
model_configuration_id=model_configuration_id,
llm_model_flow_type=flow_type,
is_default=False,
)
.on_conflict_do_nothing()
.returning(LLMModelFlow)
)
flow = result.scalar()
if not flow:
raise ValueError(
f"Failed to create new flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}"
)
return flow
def insert_new_model_configuration__no_commit(
db_session: Session,
llm_provider_id: int,
model_name: str,
supported_flows: list[LLMModelFlowType],
is_visible: bool,
max_input_tokens: int | None,
display_name: str | None,
) -> int | None:
result = db_session.execute(
insert(ModelConfiguration)
.values(
llm_provider_id=llm_provider_id,
name=model_name,
is_visible=is_visible,
max_input_tokens=max_input_tokens,
display_name=display_name,
)
.on_conflict_do_nothing()
.returning(ModelConfiguration.id)
)
model_config_id = result.scalar()
if not model_config_id:
return None
for flow_type in supported_flows:
create_new_flow_mapping__no_commit(
db_session=db_session,
model_configuration_id=model_config_id,
flow_type=flow_type,
)
return model_config_id
def update_model_configuration__no_commit(
db_session: Session,
model_configuration_id: int,
supported_flows: list[LLMModelFlowType],
is_visible: bool,
max_input_tokens: int | None,
display_name: str | None,
) -> None:
result = db_session.execute(
update(ModelConfiguration)
.values(
is_visible=is_visible,
max_input_tokens=max_input_tokens,
display_name=display_name,
)
.where(ModelConfiguration.id == model_configuration_id)
.returning(ModelConfiguration)
)
model_configuration = result.scalar()
if not model_configuration:
raise ValueError(
f"Failed to update model configuration with id={model_configuration_id}"
)
new_flows = {
flow_type
for flow_type in supported_flows
if flow_type not in model_configuration.llm_model_flow_types
}
removed_flows = {
flow_type
for flow_type in model_configuration.llm_model_flow_types
if flow_type not in supported_flows
}
for flow_type in new_flows:
create_new_flow_mapping__no_commit(
db_session=db_session,
model_configuration_id=model_configuration_id,
flow_type=flow_type,
)
for flow_type in removed_flows:
db_session.execute(
delete(LLMModelFlow).where(
LLMModelFlow.model_configuration_id == model_configuration_id,
LLMModelFlow.llm_model_flow_type == flow_type,
)
)
db_session.flush()
def _update_default_model(
db_session: Session,
provider_id: int,
model: str,
flow_type: LLMModelFlowType,
) -> None:
result = db_session.execute(
select(ModelConfiguration, LLMModelFlow)
.join(
LLMModelFlow, LLMModelFlow.model_configuration_id == ModelConfiguration.id
)
.where(
ModelConfiguration.llm_provider_id == provider_id,
ModelConfiguration.name == model,
LLMModelFlow.llm_model_flow_type == flow_type,
)
).first()
if not result:
raise ValueError(
f"Model '{model}' is not a valid model for provider_id={provider_id}"
)
model_config, new_default = result
# Clear existing default and set in an atomic operation
db_session.execute(
update(LLMModelFlow)
.where(
LLMModelFlow.llm_model_flow_type == flow_type,
LLMModelFlow.is_default == True, # noqa: E712
)
.values(is_default=False)
)
new_default.is_default = True
model_config.is_visible = True
db_session.commit()

View File

@@ -7,6 +7,7 @@ from uuid import uuid4
from pydantic import BaseModel
from sqlalchemy.orm import validates
from typing_extensions import TypedDict # noreorder
from uuid import UUID
from pydantic import ValidationError
@@ -72,6 +73,7 @@ from onyx.db.enums import (
MCPAuthenticationPerformer,
MCPTransport,
MCPServerStatus,
LLMModelFlowType,
ThemePreference,
SwitchoverType,
)
@@ -2704,14 +2706,9 @@ class LLMProvider(Base):
custom_config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
default_model_name: Mapped[str] = mapped_column(String)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# Auto mode: models, visibility, and defaults are managed by GitHub config
@@ -2761,8 +2758,6 @@ class ModelConfiguration(Base):
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Human-readable display name for the model.
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
@@ -2773,6 +2768,51 @@ class ModelConfiguration(Base):
back_populates="model_configurations",
)
llm_model_flows: Mapped[list["LLMModelFlow"]] = relationship(
"LLMModelFlow",
back_populates="model_configuration",
cascade="all, delete-orphan",
passive_deletes=True,
)
@property
def llm_model_flow_types(self) -> list[LLMModelFlowType]:
return [flow.llm_model_flow_type for flow in self.llm_model_flows]
class LLMModelFlow(Base):
__tablename__ = "llm_model_flow"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
llm_model_flow_type: Mapped[LLMModelFlowType] = mapped_column(
Enum(LLMModelFlowType, native_enum=False), nullable=False
)
model_configuration_id: Mapped[int] = mapped_column(
ForeignKey("model_configuration.id", ondelete="CASCADE"),
nullable=False,
)
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
model_configuration: Mapped["ModelConfiguration"] = relationship(
"ModelConfiguration",
back_populates="llm_model_flows",
)
__table_args__ = (
UniqueConstraint(
"llm_model_flow_type",
"model_configuration_id",
name="uq_model_config_per_llm_model_flow_type",
),
Index(
"ix_one_default_per_llm_model_flow",
"llm_model_flow_type",
unique=True,
postgresql_where=(is_default == True), # noqa: E712
),
)
class ImageGenerationConfig(Base):
__tablename__ = "image_generation_config"

View File

@@ -1,22 +1,20 @@
from collections.abc import Callable
from sqlalchemy.orm import Session
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
@@ -52,54 +50,6 @@ def _build_provider_extra_headers(
return {}
def get_llm_config_for_persona(
persona: Persona,
db_session: Session,
llm_override: LLMOverride | None = None,
) -> LLMConfig:
"""Get LLM config from persona without access checks.
This function assumes access to the persona has already been verified.
Use this when you need the LLM config but don't need to create the full LLM object.
"""
provider_name_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
provider_name = provider_name_override or persona.llm_model_provider_override
if not provider_name:
llm_provider = fetch_default_provider(db_session)
if not llm_provider:
raise ValueError("No default LLM provider found")
model_name: str | None = llm_provider.default_model_name
else:
llm_provider = fetch_llm_provider_view(db_session, provider_name)
if not llm_provider:
raise ValueError(f"No LLM provider found with name: {provider_name}")
model_name = model_version_override or persona.llm_model_version_override
if not model_name:
model_name = llm_provider.default_model_name
if not model_name:
raise ValueError("No model name found")
max_input_tokens = get_max_input_tokens_from_llm_provider(
llm_provider=llm_provider, model_name=model_name
)
return LLMConfig(
model_provider=llm_provider.provider,
model_name=model_name,
temperature=temperature_override or GEN_AI_TEMPERATURE,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,
custom_config=llm_provider.custom_config,
max_input_tokens=max_input_tokens,
)
def get_llm_for_persona(
persona: Persona | PersonaOverrideConfig | None,
user: User,
@@ -203,46 +153,41 @@ def get_default_llm_with_vision(
with get_session_with_current_tenant() as db_session:
# Try the default vision provider first
default_provider = fetch_default_vision_provider(db_session)
if default_provider and default_provider.default_vision_model:
default_model = fetch_default_vision_model(db_session)
if default_model:
if model_supports_image_input(
default_provider.default_vision_model, default_provider.provider
default_model.name, default_model.llm_provider.provider
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
LLMProviderView.from_model(default_model.llm_provider),
default_model.name,
)
# Fall back to searching all providers
providers = fetch_existing_llm_providers(db_session)
models = fetch_existing_models(
db_session=db_session,
flow_types=[LLMModelFlowType.VISION, LLMModelFlowType.CHAT],
)
if not providers:
if not models:
return None
# Check all providers for viable vision models
for provider in providers:
provider_view = LLMProviderView.from_model(provider)
# Search for viable vision model followed by conversation models
# Sort models from VISION to CONVERSATION priority
sorted_models = sorted(
models,
key=lambda x: (
LLMModelFlowType.VISION in x.llm_model_flow_types,
LLMModelFlowType.CHAT in x.llm_model_flow_types,
),
reverse=True,
)
# First priority: Check if provider has a default_vision_model
if provider.default_vision_model and model_supports_image_input(
provider.default_vision_model, provider.provider
):
return create_vision_llm(provider_view, provider.default_vision_model)
# If no model-configurations are specified, try default model
if not provider.model_configurations:
# Try default_model_name
if provider.default_model_name and model_supports_image_input(
provider.default_model_name, provider.provider
):
return create_vision_llm(provider_view, provider.default_model_name)
# Otherwise, if model-configurations are specified, check each model
else:
for model_configuration in provider.model_configurations:
if model_supports_image_input(
model_configuration.name, provider.provider
):
return create_vision_llm(provider_view, model_configuration.name)
for model in sorted_models:
if model_supports_image_input(model.name, model.llm_provider.provider):
return create_vision_llm(
LLMProviderView.from_model(model.llm_provider),
model.name,
)
return None
@@ -288,22 +233,18 @@ def get_default_llm(
additional_headers: dict[str, str] | None = None,
) -> LLM:
with get_session_with_current_tenant() as db_session:
llm_provider = fetch_default_provider(db_session)
model = fetch_default_llm_model(db_session)
if not llm_provider:
raise ValueError("No default LLM provider found")
if not model:
raise ValueError("No default LLM model found")
model_name = llm_provider.default_model_name
if not model_name:
raise ValueError("No default model name found")
return llm_from_provider(
model_name=model_name,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
)
return llm_from_provider(
model_name=model.name,
llm_provider=LLMProviderView.from_model(model.llm_provider),
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
)
def get_llm(

View File

@@ -17,6 +17,7 @@ from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
from onyx.llm.constants import LlmProviderNames
@@ -689,8 +690,11 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
LLMProvider.provider == model_provider,
)
)
if model_config and model_config.supports_image_input is not None:
return model_config.supports_image_input
if (
model_config
and LLMModelFlowType.VISION in model_config.llm_model_flow_types
):
return True
except Exception as e:
logger.warning(
f"Failed to query database for {model_provider} model {model_name} image support: {e}"

View File

@@ -28,7 +28,7 @@ from sqlalchemy.orm import Session as DBSession
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import MessageType
from onyx.db.enums import SandboxStatus
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -338,15 +338,15 @@ class SessionManager:
)
# Fallback to system default
default_provider = fetch_default_provider(self._db_session)
if not default_provider:
raise ValueError("No LLM provider configured")
default_model = fetch_default_llm_model(self._db_session)
if not default_model:
raise ValueError("No default LLM model found")
return LLMProviderConfig(
provider=default_provider.provider,
model_name=default_provider.default_model_name,
api_key=default_provider.api_key,
api_base=default_provider.api_base,
provider=default_model.llm_provider.provider,
model_name=default_model.name,
api_key=default_model.llm_provider.api_key,
api_base=default_model.llm_provider.api_base,
)
# =========================================================================

View File

@@ -93,7 +93,6 @@ def _build_llm_provider_request(
api_key=source_provider.api_key, # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -132,7 +131,6 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -164,7 +162,6 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,

View File

@@ -1,4 +1,5 @@
import os
from collections import defaultdict
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -19,9 +20,13 @@ from onyx.auth.schemas import UserRole
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
from onyx.db.llm import fetch_user_group_ids
from onyx.db.llm import remove_llm_provider
@@ -37,7 +42,6 @@ from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import get_bedrock_token_limit
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.llm.well_known_providers.auto_update_service import (
fetch_llm_recommendations_from_github,
@@ -50,11 +54,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -211,7 +216,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -246,13 +251,15 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session, exclude_image_generation_providers=not include_image_gen
db_session=db_session,
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
@@ -269,7 +276,25 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
default_model = None
if model_config := fetch_default_llm_model(db_session):
default_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
default_vision_model = None
if model_config := fetch_default_vision_model(db_session):
default_vision_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=default_model,
default_vision=default_vision_model,
)
@admin_router.put("/provider")
@@ -328,20 +353,6 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -398,26 +409,29 @@ def delete_llm_provider(
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/provider/{provider_id}/default")
@admin_router.post("/default")
def set_provider_as_default(
provider_id: int,
default_model_request: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
@admin_router.post("/provider/{provider_id}/default-vision")
@admin_router.post("/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
default_model_request: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_model_request.provider_id,
vision_model=default_model_request.model_name,
db_session=db_session,
)
@@ -443,43 +457,47 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
) -> LLMProviderResponse[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
)
providers = fetch_existing_llm_providers(db_session)
vision_providers = []
# Group vision models by provider ID (using ID as key since it's hashable)
provider_models: dict[int, list[str]] = defaultdict(list)
providers_by_id: dict[int, LLMProviderView] = {}
logger.info("Fetching vision-capable providers")
for provider in providers:
vision_models = []
# Check each model for vision capability
for model_configuration in provider.model_configurations:
if model_supports_image_input(model_configuration.name, provider.provider):
vision_models.append(model_configuration.name)
logger.debug(
f"Vision model found: {provider.provider}/{model_configuration.name}"
)
# Only include providers with at least one vision-capable model
if vision_models:
provider_view = LLMProviderView.from_model(provider)
for vision_model in vision_models:
provider_id = vision_model.llm_provider.id
provider_models[provider_id].append(vision_model.name)
# Only create the view once per provider
if provider_id not in providers_by_id:
provider_view = LLMProviderView.from_model(vision_model.llm_provider)
_mask_provider_credentials(provider_view)
providers_by_id[provider_id] = provider_view
vision_providers.append(
VisionProviderResponse(
**provider_view.model_dump(),
vision_models=vision_models,
)
)
# Build response list
vision_provider_response = [
VisionProviderResponse(
**providers_by_id[provider_id].model_dump(),
vision_models=model_names,
)
for provider_id, model_names in provider_models.items()
]
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"
)
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
logger.info(f"Found {len(vision_providers)} vision-capable providers")
return vision_providers
default_vision_model = None
if model_config := fetch_default_vision_model(db_session):
default_vision_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=default_vision_model,
)
"""Endpoints for all"""
@@ -489,7 +507,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -502,7 +520,9 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(db_session)
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN
@@ -526,7 +546,25 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
default_model = None
if model_config := fetch_default_llm_model(db_session):
default_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
default_vision_model = None
if model_config := fetch_default_vision_model(db_session):
default_vision_model = DefaultModel(
provider_id=model_config.llm_provider.id,
model_name=model_config.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=default_model,
default_vision=default_vision_model,
)
def get_valid_model_names_for_persona(
@@ -545,7 +583,9 @@ def get_valid_model_names_for_persona(
return []
is_admin = user.role == UserRole.ADMIN
all_providers = fetch_existing_llm_providers(db_session)
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
valid_models = []
@@ -567,7 +607,7 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
@@ -592,7 +632,9 @@ def list_llm_providers_for_persona(
)
is_admin = user.role == UserRole.ADMIN
all_providers = fetch_existing_llm_providers(db_session)
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
llm_provider_list: list[LLMProviderDescriptor] = []
@@ -612,7 +654,11 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
return llm_provider_list
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=None,
default_vision=None,
)
@admin_router.get("/provider-contextual-cost")
@@ -630,7 +676,7 @@ def get_provider_contextual_cost(
- the chunk_context
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
"""
providers = fetch_existing_llm_providers(db_session)
providers = fetch_existing_llm_providers(db_session, [LLMModelFlowType.CHAT])
costs = []
for provider in providers:
for model_configuration in provider.model_configurations:

View File

@@ -1,10 +1,13 @@
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.db.enums import LLMModelFlowType
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import litellm_thinks_model_supports_image_input
from onyx.llm.utils import model_is_reasoning_model
@@ -20,22 +23,22 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -51,13 +54,10 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -72,13 +72,10 @@ class LLMProviderDescriptor(BaseModel):
provider = llm_provider_model.provider
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=llm_provider_model.default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -92,13 +89,11 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -119,8 +114,6 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -150,10 +143,6 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=llm_provider_model.default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -180,7 +169,8 @@ class ModelConfigurationUpsertRequest(BaseModel):
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=model_configuration_model.supports_image_input,
supports_image_input=LLMModelFlowType.VISION
in model_configuration_model.llm_model_flow_types,
display_name=model_configuration_model.display_name,
)
@@ -219,7 +209,8 @@ class ModelConfigurationView(BaseModel):
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=(
model_configuration_model.supports_image_input or False
LLMModelFlowType.VISION
in model_configuration_model.llm_model_flow_types
),
# Infer reasoning support from model name/display name
supports_reasoning=is_reasoning_model(
@@ -261,8 +252,9 @@ class ModelConfigurationView(BaseModel):
)
),
supports_image_input=(
val
if (val := model_configuration_model.supports_image_input) is not None
True
if LLMModelFlowType.VISION
in model_configuration_model.llm_model_flow_types
else litellm_thinks_model_supports_image_input(
model_configuration_model.name, provider_name
)
@@ -371,3 +363,27 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> "LLMProviderResponse[T]":
return cls(
providers=providers,
default_text=default_text,
default_vision=default_vision,
)

View File

@@ -23,7 +23,7 @@ from onyx.db.document import check_docs_exist
from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -286,7 +286,11 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
if (
GEN_AI_API_KEY
and fetch_default_llm_model(db_session) is None
and not INTEGRATION_TESTS_MODE
):
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
@@ -298,7 +302,6 @@ def setup_postgres(db_session: Session) -> None:
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -310,7 +313,9 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
def update_default_multipass_indexing(db_session: Session) -> None:

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -26,7 +26,6 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
},
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
"api_key_changed": True,
}
@@ -43,7 +42,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -52,7 +51,6 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
},
"model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
"api_key_changed": True,
}

View File

@@ -12,6 +12,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
# Counter for generating unique file IDs in mock file store
@@ -27,14 +28,19 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
is_visible=True,
)
],
groups=[],
)
provider = upsert_llm_provider(
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
print(f"Note: Could not create LLM provider: {exc}")

View File

@@ -10,6 +10,7 @@ from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -36,7 +37,7 @@ def test_answer_with_only_anthropic_provider(
assert anthropic_api_key, "ANTHROPIC_API_KEY environment variable must be set"
# Drop any existing providers so that only Anthropic is available.
for provider in fetch_existing_llm_providers(db_session):
for provider in fetch_existing_llm_providers(db_session, [LLMModelFlowType.CHAT]):
remove_llm_provider(db_session, provider.id)
anthropic_model = "claude-haiku-4-5-20251001"
@@ -47,7 +48,6 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -59,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, db_session)
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -14,6 +14,7 @@ import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -106,12 +107,7 @@ class TestLLMConfigurationEndpoint:
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -156,12 +152,7 @@ class TestLLMConfigurationEndpoint:
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -206,12 +197,7 @@ class TestLLMConfigurationEndpoint:
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -258,12 +244,7 @@ class TestLLMConfigurationEndpoint:
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -304,12 +285,7 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
db_session=db_session,
)
@@ -326,12 +302,7 @@ class TestLLMConfigurationEndpoint:
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -372,12 +343,7 @@ class TestLLMConfigurationEndpoint:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
model=model_name,
),
_=_create_mock_admin(),
db_session=db_session,
@@ -441,7 +407,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -451,7 +416,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -471,7 +436,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -502,7 +466,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -511,6 +474,9 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -523,7 +489,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -551,7 +517,9 @@ class TestDefaultProviderEndpoint:
from onyx.db.llm import fetch_existing_llm_providers
try:
existing_providers = fetch_existing_llm_providers(db_session)
existing_providers = fetch_existing_llm_providers(
db_session, flow_types=[LLMModelFlowType.CHAT]
)
provider_names_to_restore: list[str] = []
for provider in existing_providers:
@@ -593,7 +561,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -602,7 +569,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
# Test should fail
with patch(

View File

@@ -500,7 +500,7 @@ def test_upload_with_custom_config_then_change(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -539,7 +539,7 @@ def test_upload_with_custom_config_then_change(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True

View File

@@ -14,7 +14,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -132,7 +132,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -160,25 +159,20 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, expected_default_model, db_session)
# Step 5: Fetch the default provider and verify
default_provider = fetch_default_provider(db_session)
assert default_provider is not None, "Default provider should exist"
default_model = fetch_default_llm_model(db_session)
assert default_model is not None, "Default provider should exist"
assert (
default_provider.name == provider_name
default_model.llm_provider.name == provider_name
), "Default provider should be our test provider"
assert (
default_provider.default_model_name == expected_default_model
default_model.name == expected_default_model
), f"Default provider's default model should be '{expected_default_model}'"
assert (
default_provider.is_auto_mode is True
default_model.llm_provider.is_auto_mode is True
), "Default provider should be in auto mode"
finally:
@@ -235,7 +229,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -314,7 +307,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -346,7 +338,6 @@ class TestAutoModeSyncFeature:
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -385,9 +376,6 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -429,7 +417,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -532,7 +519,6 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -546,7 +532,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_default_model, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -560,7 +546,6 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -570,26 +555,26 @@ class TestAutoModeSyncFeature:
# Step 3: Verify provider 1 is still the default
db_session.expire_all()
default_provider = fetch_default_provider(db_session)
assert default_provider is not None
assert default_provider.name == provider_1_name
assert default_provider.default_model_name == provider_1_default_model
assert default_provider.is_auto_mode is True
default_model = fetch_default_llm_model(db_session)
assert default_model is not None
assert default_model.llm_provider.name == provider_1_name
assert default_model.name == provider_1_default_model
assert default_model.llm_provider.is_auto_mode is True
# Step 4: Change the default to provider 2
provider_2 = fetch_existing_llm_provider(
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()
default_provider = fetch_default_provider(db_session)
assert default_provider is not None
assert default_provider.name == provider_2_name
assert default_provider.default_model_name == provider_2_default_model
assert default_provider.is_auto_mode is True
default_model = fetch_default_llm_model(db_session)
assert default_model is not None
assert default_model.llm_provider.name == provider_2_name
assert default_model.name == provider_2_default_model
assert default_model.llm_provider.is_auto_mode is True
# Step 6: Run test_default_provider and verify it uses provider 2's model
with patch(

View File

@@ -5,6 +5,11 @@ from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
# Set environment variables to disable model server for testing
os.environ["DISABLE_MODEL_SERVER"] = "true"
os.environ["MODEL_SERVER_HOST"] = "disabled"
@@ -403,9 +408,13 @@ class TestSlackBotFederatedSearch:
def _setup_llm_provider(self, db_session: Session) -> None:
"""Create a default LLM provider in the database for testing with real API key"""
# Delete any existing default LLM provider to ensure clean state
# Use SQL-level delete to properly trigger ON DELETE CASCADE
# (ORM-level delete tries to set foreign keys to NULL instead)
from sqlalchemy import delete
existing_providers = db_session.query(LLMProvider).all()
for provider in existing_providers:
db_session.delete(provider)
db_session.execute(delete(LLMProvider).where(LLMProvider.id == provider.id))
db_session.commit()
api_key = os.getenv("OPENAI_API_KEY")
@@ -414,16 +423,25 @@ class TestSlackBotFederatedSearch:
"OPENAI_API_KEY environment variable not set - test requires real API key"
)
llm_provider = LLMProvider(
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_default_provider=True,
is_public=True,
provider_view = upsert_llm_provider(
LLMProviderUpsertRequest(
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
max_input_tokens=None,
display_name="gpt-4o",
),
],
),
db_session=db_session,
)
db_session.add(llm_provider)
db_session.commit()
update_default_provider(provider_view.id, "gpt-4o", db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""

View File

@@ -4,8 +4,10 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
@@ -36,7 +38,6 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -44,7 +45,15 @@ class LLMProviderManager:
is_public=True if is_public is None else is_public,
groups=groups or [],
personas=personas or [],
model_configurations=[],
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name or "gpt-4o-mini",
is_visible=True,
max_input_tokens=None,
display_name=default_model_name or "gpt-4o-mini",
supports_image_input=True,
)
],
api_key_changed=True,
)
@@ -125,7 +134,12 @@ class LLMProviderManager:
user_performing_action: DATestUser | None = None,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -138,11 +152,25 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and default_model.model_name in model_names
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/default",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return DefaultModel(**response.json())

View File

@@ -79,6 +79,7 @@ export { default as SvgHourglass } from "@opal/icons/hourglass";
export { default as SvgImage } from "@opal/icons/image";
export { default as SvgImageSmall } from "@opal/icons/image-small";
export { default as SvgImport } from "@opal/icons/import";
export { default as SvgInfo } from "@opal/icons/info";
export { default as SvgInfoSmall } from "@opal/icons/info-small";
export { default as SvgKey } from "@opal/icons/key";
export { default as SvgKeystroke } from "@opal/icons/keystroke";

View File

@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgInfo = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M8.00001 10.6666V7.99998M8.00001 5.33331H8.00668M14.6667 7.99998C14.6667 11.6819 11.6819 14.6666 8.00001 14.6666C4.31811 14.6666 1.33334 11.6819 1.33334 7.99998C1.33334 4.31808 4.31811 1.33331 8.00001 1.33331C11.6819 1.33331 14.6667 4.31808 14.6667 7.99998Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgInfo;

View File

@@ -6,7 +6,7 @@ import { Callout } from "@/components/ui/callout";
import Text from "@/refresh-components/texts/Text";
import Title from "@/components/ui/title";
import { ThreeDotsLoader } from "@/components/Loading";
import { LLMProviderView } from "./interfaces";
import { LLMProviderResponse, LLMProviderView } from "./interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { OpenAIForm } from "./forms/OpenAIForm";
import { AnthropicForm } from "./forms/AnthropicForm";
@@ -17,44 +17,59 @@ import { VertexAIForm } from "./forms/VertexAIForm";
import { OpenRouterForm } from "./forms/OpenRouterForm";
import { getFormForExistingProvider } from "./forms/getForm";
import { CustomForm } from "./forms/CustomForm";
import { DefaultModelSelector } from "./forms/components/DefaultModel";
import * as GeneralLayouts from "@/layouts/general-layouts";
import Separator from "@/refresh-components/Separator";
import { useLlmManager } from "@/lib/hooks";
export function LLMConfiguration() {
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
LLM_PROVIDERS_ADMIN_URL,
errorHandlingFetcher
);
const { data: existingLLMProvidersResponse } = useSWR<
LLMProviderResponse<LLMProviderView>
>(LLM_PROVIDERS_ADMIN_URL, errorHandlingFetcher);
if (!existingLlmProviders) {
if (!existingLLMProvidersResponse) {
return <ThreeDotsLoader />;
}
const existingLlmProviders = existingLLMProvidersResponse.providers;
const defaultProviderId =
existingLLMProvidersResponse.default_text?.provider_id;
const defaultLlmModel =
existingLLMProvidersResponse.default_text ?? undefined;
const isFirstProvider = existingLlmProviders.length === 0;
const { updateDefaultLlmModel } = useLlmManager();
return (
<>
<Title className="mb-2">Enabled LLM Providers</Title>
{existingLlmProviders.length > 0 ? (
<>
<Text as="p" className="mb-4">
If multiple LLM providers are enabled, the default provider will be
used for all &quot;Default&quot; Assistants. For user-created
Assistants, you can select the LLM provider/model that best fits the
use case!
</Text>
<div className="flex flex-col gap-y-4">
<DefaultModelSelector
existingLlmProviders={existingLlmProviders}
defaultLlmModel={defaultLlmModel ?? null}
onModelChange={(provider_id, model_name) =>
updateDefaultLlmModel({ provider_id, model_name })
}
/>
<GeneralLayouts.Section
flexDirection="column"
justifyContent="start"
alignItems="start"
>
<Text headingH3>Available Providers</Text>
{[...existingLlmProviders]
.sort((a, b) => {
if (a.is_default_provider && !b.is_default_provider) return -1;
if (!a.is_default_provider && b.is_default_provider) return 1;
if (a.id === defaultProviderId && b.id !== defaultProviderId)
return -1;
if (a.id !== defaultProviderId && b.id === defaultProviderId)
return 1;
return 0;
})
.map((llmProvider) => (
<div key={llmProvider.id}>
{getFormForExistingProvider(llmProvider)}
</div>
))}
</div>
.map((llmProvider) => getFormForExistingProvider(llmProvider))}
</GeneralLayouts.Section>
<Separator />
</>
) : (
<Callout type="warning" title="No LLM providers configured yet">
@@ -62,19 +77,25 @@ export function LLMConfiguration() {
</Callout>
)}
<Title className="mb-2 mt-6">Add LLM Provider</Title>
<Text as="p" className="mb-4">
Add a new LLM provider by either selecting from one of the default
providers or by specifying your own custom LLM provider.
</Text>
<GeneralLayouts.Section
flexDirection="column"
justifyContent="start"
alignItems="start"
gap={0}
>
<Text headingH3>Add Provider</Text>
<Text as="p" secondaryBody text03>
Onyx supports both popular providers and self-hosted models.
</Text>
</GeneralLayouts.Section>
<div className="flex flex-col gap-y-4">
<div className="grid grid-cols-2 gap-4">
<OpenAIForm shouldMarkAsDefault={isFirstProvider} />
<AnthropicForm shouldMarkAsDefault={isFirstProvider} />
<OllamaForm shouldMarkAsDefault={isFirstProvider} />
<VertexAIForm shouldMarkAsDefault={isFirstProvider} />
<AzureForm shouldMarkAsDefault={isFirstProvider} />
<BedrockForm shouldMarkAsDefault={isFirstProvider} />
<VertexAIForm shouldMarkAsDefault={isFirstProvider} />
<OpenRouterForm shouldMarkAsDefault={isFirstProvider} />
<CustomForm shouldMarkAsDefault={isFirstProvider} />

View File

@@ -12,6 +12,7 @@ export const ProviderIcon = ({
size = 16,
className = defaultTailwindCSS,
}: ProviderIconProps) => {
console.log(provider);
const Icon = getProviderIcon(provider, modelName);
return <Icon size={size} className={className} />;
};

View File

@@ -1,4 +1,5 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const LLM_ADMIN_URL = "/api/admin/llm";
export const LLM_PROVIDERS_ADMIN_URL = `${LLM_ADMIN_URL}/provider`;
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
"/api/admin/llm/provider-contextual-cost";

View File

@@ -17,17 +17,22 @@ import {
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import Separator from "@/refresh-components/Separator";
import InputWrapper from "./components/InputWrapper";
export const ANTHROPIC_PROVIDER_NAME = "anthropic";
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";
export function AnthropicForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="Anthropic"
providerDisplayName={existingLlmProvider?.name ?? "Claude"}
providerInternalName="anthropic"
providerEndpoint={ANTHROPIC_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -46,19 +51,28 @@ export function AnthropicForm({
existingLlmProvider,
wellKnownLLMProvider
);
const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME;
const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name ?? DEFAULT_DEFAULT_MODEL_NAME
: undefined;
const initialValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? "",
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
// Default to auto mode for new Anthropic providers
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
is_auto_mode: isAutoMode,
};
const validationSchema = buildDefaultValidationSchema().shape({
@@ -92,16 +106,26 @@ export function AnthropicForm({
{(formikProps) => {
return (
<Form className={LLM_FORM_CLASS_NAME}>
<InputWrapper
label="API Key"
description="Paste your {link} from Anthropic to access your models."
descriptionLink={{
text: "API key",
href: "https://console.anthropic.com/dashboard",
}}
>
<PasswordInputTypeInField name="api_key" />
</InputWrapper>
<Separator noPadding />
<DisplayNameField disabled={!!existingLlmProvider} />
<PasswordInputTypeInField name="api_key" label="API Key" />
<Separator noPadding />
<DisplayModels
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>

View File

@@ -18,15 +18,16 @@ import {
LLM_FORM_CLASS_NAME,
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { SingleDefaultModelField } from "./components/SingleDefaultModelField";
import {
isValidAzureTargetUri,
parseAzureTargetUri,
} from "@/lib/azureTargetUri";
import Separator from "@/refresh-components/Separator";
import { DisplayModels } from "./components/DisplayModels";
import InputWrapper from "./components/InputWrapper";
export const AZURE_PROVIDER_NAME = "azure";
const AZURE_DISPLAY_NAME = "Microsoft Azure Cloud";
const AZURE_DISPLAY_NAME = "Azure OpenAI";
interface AzureFormValues extends BaseLLMFormValues {
api_key: string;
@@ -53,7 +54,9 @@ export function AzureForm({
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName={AZURE_DISPLAY_NAME}
providerName={"Microsoft Azure"}
providerDisplayName={existingLlmProvider?.name ?? AZURE_DISPLAY_NAME}
providerInternalName={"azure"}
providerEndpoint={AZURE_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -138,19 +141,41 @@ export function AzureForm({
{(formikProps) => {
return (
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<PasswordInputTypeInField name="api_key" label="API Key" />
<TextFormField
name="target_uri"
<InputWrapper
label="Target URI"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
subtext="The complete target URI for your deployment from the Azure AI portal."
/>
description="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
>
<TextFormField
name="target_uri"
label=""
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
/>
</InputWrapper>
<InputWrapper
label="API Key"
description="Paste your API key from {link} to access your models."
descriptionLink={{
text: "Azure OpenAI",
href: "https://oai.azure.com",
}}
>
<PasswordInputTypeInField name="api_key" placeholder="" />
</InputWrapper>
<Separator />
<SingleDefaultModelField placeholder="E.g. gpt-4o" />
<DisplayNameField disabled={!!existingLlmProvider} />
<Separator noPadding />
<DisplayModels
modelConfigurations={modelConfigurations}
formikProps={formikProps}
shouldShowAutoUpdateToggle={false}
noModelConfigurationsMessage="No models available. Provide a valid base URL or key."
/>
<Separator />
<AdvancedOptions formikProps={formikProps} />

View File

@@ -2,8 +2,10 @@
import { useState, useEffect } from "react";
import { Form, Formik, FormikProps } from "formik";
import { SelectorFormField, TextFormField } from "@/components/Field";
import { TextFormField } from "@/components/Field";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import {
LLMProviderFormProps,
LLMProviderView,
@@ -16,7 +18,6 @@ import {
} from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { FormActionButtons } from "./components/FormActionButtons";
import { FetchModelsButton } from "./components/FetchModelsButton";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
@@ -27,14 +28,16 @@ import {
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { fetchBedrockModels } from "../utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import Tabs from "@/refresh-components/Tabs";
import { cn } from "@/lib/utils";
import { Card } from "@/refresh-components/cards";
import { SvgInfo } from "@opal/icons";
import * as GeneralLayouts from "@/layouts/general-layouts";
import InputWrapper from "./components/InputWrapper";
export const BEDROCK_PROVIDER_NAME = "bedrock";
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
const BEDROCK_DISPLAY_NAME = "Amazon Bedrock";
// AWS Bedrock regions - kept in sync with backend
const AWS_REGION_OPTIONS = [
@@ -58,6 +61,13 @@ const AUTH_METHOD_IAM = "iam";
const AUTH_METHOD_ACCESS_KEY = "access_key";
const AUTH_METHOD_LONG_TERM_API_KEY = "long_term_api_key";
// Auth method options
const AUTH_METHOD_OPTIONS = [
{ name: "IAM Role", value: AUTH_METHOD_IAM },
{ name: "Access Key", value: AUTH_METHOD_ACCESS_KEY },
{ name: "Long-term API Key", value: AUTH_METHOD_LONG_TERM_API_KEY },
];
// Field name constants
const FIELD_AWS_REGION_NAME = "custom_config.AWS_REGION_NAME";
const FIELD_BEDROCK_AUTH_METHOD = "custom_config.BEDROCK_AUTH_METHOD";
@@ -135,99 +145,96 @@ function BedrockFormInternals({
const isFetchDisabled =
!formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;
const authDisplay = () => {
if (authMethod === AUTH_METHOD_IAM) {
return (
<>
<Card variant="secondary">
<GeneralLayouts.Section
flexDirection="row"
alignItems="center"
justifyContent="start"
>
<SvgInfo className="w-4 h-4" />
<Text as="p" text03>
Onyx will use the IAM role attached to the environment it's
running in to authenticate.
</Text>
</GeneralLayouts.Section>
</Card>
</>
);
} else if (authMethod === AUTH_METHOD_ACCESS_KEY) {
return (
<>
<TextFormField
name={FIELD_AWS_ACCESS_KEY_ID}
label="Access Key ID"
placeholder=""
/>
<PasswordInputTypeInField
name={FIELD_AWS_SECRET_ACCESS_KEY}
label="Secret Access Key"
placeholder=""
/>
</>
);
} else if (authMethod === AUTH_METHOD_LONG_TERM_API_KEY) {
return (
<>
<PasswordInputTypeInField
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
label="Long-term API Key"
placeholder=""
/>
</>
);
}
};
return (
<Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
<DisplayNameField disabled={!!existingLlmProvider} />
<SelectorFormField
name={FIELD_AWS_REGION_NAME}
label="AWS Region"
subtext="Region where your Amazon Bedrock models are hosted."
options={AWS_REGION_OPTIONS}
/>
<div>
<Text as="p" mainUiAction>
Authentication Method
</Text>
<InputSelectField name={FIELD_AWS_REGION_NAME}>
<Text as="p">AWS Region Name</Text>
<InputSelect.Trigger placeholder="Select AWS region" />
<InputSelect.Content>
{AWS_REGION_OPTIONS.map((option) => (
<InputSelect.Item key={option.value} value={option.value}>
{option.name}
</InputSelect.Item>
))}
</InputSelect.Content>
<Text as="p" secondaryBody text03>
Choose how Onyx should authenticate with Bedrock.
Region where your Amazon Bedrock models are hosted. See full list of
regions supported at AWS.
</Text>
<Tabs
value={authMethod || AUTH_METHOD_ACCESS_KEY}
onValueChange={(value) =>
formikProps.setFieldValue(FIELD_BEDROCK_AUTH_METHOD, value)
}
</InputSelectField>
<InputSelectField name={FIELD_BEDROCK_AUTH_METHOD}>
<InputWrapper
label="Authentication Method"
description="See {link} for more instructions."
descriptionLink={{
text: "documentation",
href: `https://docs.onyx.app/admin/ai_models/bedrock#authentication-methods`,
}}
>
<Tabs.List>
<Tabs.Trigger value={AUTH_METHOD_IAM}>IAM Role</Tabs.Trigger>
<Tabs.Trigger value={AUTH_METHOD_ACCESS_KEY}>
Access Key
</Tabs.Trigger>
<Tabs.Trigger value={AUTH_METHOD_LONG_TERM_API_KEY}>
Long-term API Key
</Tabs.Trigger>
</Tabs.List>
<InputSelect.Trigger placeholder="Select authentication method" />
<InputSelect.Content>
{AUTH_METHOD_OPTIONS.map((option) => (
<InputSelect.Item key={option.value} value={option.value}>
{option.name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputWrapper>
</InputSelectField>
<Tabs.Content value={AUTH_METHOD_IAM}>
<Text as="p" text03>
Uses the IAM role attached to your AWS environment. Recommended
for EC2, ECS, Lambda, or other AWS services.
</Text>
</Tabs.Content>
{authDisplay()}
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
<div className="flex flex-col gap-4 w-full">
<TextFormField
name={FIELD_AWS_ACCESS_KEY_ID}
label="AWS Access Key ID"
placeholder="AKIAIOSFODNN7EXAMPLE"
/>
<PasswordInputTypeInField
name={FIELD_AWS_SECRET_ACCESS_KEY}
label="AWS Secret Access Key"
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
/>
</div>
</Tabs.Content>
<Separator noPadding />
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
<div className="flex flex-col gap-4 w-full">
<PasswordInputTypeInField
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
label="AWS Bedrock Long-term API Key"
placeholder="Your long-term API key"
/>
</div>
</Tabs.Content>
</Tabs>
</div>
<FetchModelsButton
onFetch={() =>
fetchBedrockModels({
aws_region_name:
formikProps.values.custom_config?.AWS_REGION_NAME ?? "",
aws_access_key_id:
formikProps.values.custom_config?.AWS_ACCESS_KEY_ID,
aws_secret_access_key:
formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
aws_bearer_token_bedrock:
formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
provider_name: existingLlmProvider?.name,
})
}
isDisabled={isFetchDisabled}
disabledHint={
!formikProps.values.custom_config?.AWS_REGION_NAME
? "Select an AWS region."
: !isAuthComplete
? 'Complete the "Authentication Method" section.'
: undefined
}
onModelsFetched={setFetchedModels}
autoFetchOnInitialLoad={!!existingLlmProvider}
/>
<DisplayNameField />
<Separator />
@@ -235,10 +242,8 @@ function BedrockFormInternals({
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage={
"Fetch available models first, then you'll be able to select " +
"the models you want to make available in Onyx."
"No models available. Provide a valid region and key."
}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>
@@ -266,7 +271,9 @@ export function BedrockForm({
return (
<ProviderFormEntrypointWrapper
providerName={BEDROCK_DISPLAY_NAME}
providerName={"AWS"}
providerDisplayName={existingLlmProvider?.name ?? BEDROCK_DISPLAY_NAME}
providerInternalName={"bedrock"}
existingLlmProvider={existingLlmProvider}
>
{({

View File

@@ -326,9 +326,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
// Verify set as default API was called
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith(
"/api/admin/llm/provider/5/default",
"/api/admin/llm/provider/default",
expect.objectContaining({
method: "POST",
body: JSON.stringify({
provider_id: 5,
model_name: "gpt-4",
}),
})
);
});

View File

@@ -43,10 +43,9 @@ export function CustomForm({
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="Custom LLM"
providerName="LiteLLM"
providerDisplayName={existingLlmProvider?.name ?? "LiteLLM Models"}
existingLlmProvider={existingLlmProvider}
buttonMode={!existingLlmProvider}
buttonText="Add Custom LLM Provider"
>
{({
onClose,

View File

@@ -25,10 +25,17 @@ import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { useEffect, useState } from "react";
import { fetchOllamaModels } from "../utils";
import Tabs from "@/refresh-components/Tabs";
import Separator from "@/refresh-components/Separator";
export const OLLAMA_PROVIDER_NAME = "ollama_chat";
const DEFAULT_API_BASE = "http://127.0.0.1:11434";
enum OllamaHostingTab {
SelfHosted = "self-hosted",
Cloud = "cloud",
}
interface OllamaFormValues extends BaseLLMFormValues {
api_base: string;
custom_config: {
@@ -60,6 +67,9 @@ function OllamaFormContent({
isFormValid,
}: OllamaFormContentProps) {
const [isLoadingModels, setIsLoadingModels] = useState(true);
const [activeTab, setActiveTab] = useState<OllamaHostingTab>(
OllamaHostingTab.SelfHosted
);
useEffect(() => {
if (formikProps.values.api_base) {
@@ -93,27 +103,48 @@ function OllamaFormContent({
return (
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<Tabs
value={activeTab}
onValueChange={(value) => setActiveTab(value as OllamaHostingTab)}
>
<Tabs.List>
<Tabs.Trigger value={OllamaHostingTab.SelfHosted}>
Self-hosted Ollama
</Tabs.Trigger>
<Tabs.Trigger value={OllamaHostingTab.Cloud}>
Ollama Cloud
</Tabs.Trigger>
</Tabs.List>
<TextFormField
name="api_base"
label="API Base URL"
subtext="The base URL for your Ollama instance (e.g., http://127.0.0.1:11434)"
placeholder={DEFAULT_API_BASE}
/>
<Tabs.Content value={OllamaHostingTab.SelfHosted}>
<TextFormField
name="api_base"
label="API Base URL"
subtext="Your self-hosted Ollama API base URL."
placeholder="Your Ollama API base URL."
/>
</Tabs.Content>
<PasswordInputTypeInField
name="custom_config.OLLAMA_API_KEY"
label="API Key (Optional)"
subtext="Optional API key for Ollama Cloud (https://ollama.com). Leave blank for local instances."
/>
<Tabs.Content value={OllamaHostingTab.Cloud}>
<TextFormField
name="api_key"
label="API Key"
subtext="Paste your API key from Ollama Cloud to access your models."
/>
</Tabs.Content>
</Tabs>
<Separator noPadding />
<DisplayNameField />
<Separator noPadding />
<DisplayModels
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
noModelConfigurationsMessage="No models found. Please provide a valid base URL or key."
isLoading={isLoadingModels}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>
@@ -140,6 +171,7 @@ export function OllamaForm({
return (
<ProviderFormEntrypointWrapper
providerName="Ollama"
providerDisplayName={existingLlmProvider?.name ?? "Ollama"}
existingLlmProvider={existingLlmProvider}
>
{({

View File

@@ -1,31 +1,34 @@
import { Form, Formik } from "formik";
import * as Yup from "yup";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
submitLLMProvider,
LLM_FORM_CLASS_NAME,
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import LLMFormLayout from "./components/FormLayout";
import Separator from "@/refresh-components/Separator";
import InputWrapper from "./components/InputWrapper";
export const OPENAI_PROVIDER_NAME = "openai";
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";
export function OpenAIForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="OpenAI"
providerDisplayName={existingLlmProvider?.name ?? "OpenAI"}
providerEndpoint={OPENAI_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -44,18 +47,27 @@ export function OpenAIForm({
existingLlmProvider,
wellKnownLLMProvider
);
const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME;
// We use a default model if we're editting and this provider is the global default
// Or we are creating the first provider (and shouldMarkAsDefault is true)
const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name ?? DEFAULT_DEFAULT_MODEL_NAME
: undefined;
const initialValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
api_key: existingLlmProvider?.api_key ?? "",
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
// Default to auto mode for new OpenAI providers
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
is_auto_mode: isAutoMode,
};
const validationSchema = buildDefaultValidationSchema().shape({
@@ -86,35 +98,47 @@ export function OpenAIForm({
});
}}
>
{(formikProps) => {
return (
<Form className={LLM_FORM_CLASS_NAME}>
{(formikProps) => (
<Form>
<LLMFormLayout.Body>
<InputWrapper
label="API Key"
description="Paste your {link} from OpenAI to access your models."
descriptionLink={{
text: "API key",
href: "https://platform.openai.com/api-keys",
}}
>
<PasswordInputTypeInField
name="api_key"
subtext="Paste your API key from OpenAI to access your models."
/>
</InputWrapper>
<Separator />
<DisplayNameField disabled={!!existingLlmProvider} />
<PasswordInputTypeInField name="api_key" label="API Key" />
<Separator />
<DisplayModels
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
<AdvancedOptions formikProps={formikProps} />
</LLMFormLayout.Body>
<FormActionButtons
isTesting={isTesting}
testError={testError}
existingLlmProvider={existingLlmProvider}
mutate={mutate}
onClose={onClose}
isFormValid={formikProps.isValid}
/>
</Form>
);
}}
<LLMFormLayout.Footer
onCancel={onClose}
submitLabel={existingLlmProvider ? "Update" : "Enable"}
isSubmitting={isTesting}
isSubmitDisabled={!formikProps.isValid}
error={testError}
/>
</Form>
)}
</Formik>
</>
);

View File

@@ -26,6 +26,7 @@ import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { FetchModelsButton } from "./components/FetchModelsButton";
import { useState } from "react";
import InputWrapper from "./components/InputWrapper";
export const OPENROUTER_PROVIDER_NAME = "openrouter";
const OPENROUTER_DISPLAY_NAME = "OpenRouter";
@@ -40,7 +41,6 @@ interface OpenRouterFormValues extends BaseLLMFormValues {
async function fetchOpenRouterModels(params: {
apiBase: string;
apiKey: string;
providerName?: string;
}): Promise<{ models: ModelConfiguration[]; error?: string }> {
if (!params.apiBase || !params.apiKey) {
return {
@@ -58,7 +58,6 @@ async function fetchOpenRouterModels(params: {
body: JSON.stringify({
api_base: params.apiBase,
api_key: params.apiKey,
provider_name: params.providerName,
}),
});
@@ -92,6 +91,7 @@ async function fetchOpenRouterModels(params: {
export function OpenRouterForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
@@ -99,6 +99,7 @@ export function OpenRouterForm({
return (
<ProviderFormEntrypointWrapper
providerName={OPENROUTER_DISPLAY_NAME}
providerDisplayName={existingLlmProvider?.name ?? OPENROUTER_DISPLAY_NAME}
providerEndpoint={OPENROUTER_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -117,10 +118,22 @@ export function OpenRouterForm({
existingLlmProvider,
wellKnownLLMProvider
);
const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name;
const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name
: undefined;
const initialValues: OpenRouterFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
@@ -170,36 +183,31 @@ export function OpenRouterForm({
return (
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<PasswordInputTypeInField name="api_key" label="API Key" />
<TextFormField
name="api_base"
<InputWrapper
label="API Base URL"
subtext="The base URL for OpenRouter API."
placeholder={DEFAULT_API_BASE}
/>
description="Paste your OpenRouter compatible endpoint URL or use OpenRouter API directly."
>
<TextFormField
name="api_base_url"
placeholder="https://openrouter.ai/api/v1"
label=""
/>
</InputWrapper>
<FetchModelsButton
onFetch={() =>
fetchOpenRouterModels({
apiBase: formikProps.values.api_base,
apiKey: formikProps.values.api_key,
providerName: existingLlmProvider?.name,
})
}
isDisabled={isFetchDisabled}
disabledHint={
!formikProps.values.api_key
? "Enter your API key first."
: !formikProps.values.api_base
? "Enter the API base URL."
: undefined
}
onModelsFetched={setFetchedModels}
autoFetchOnInitialLoad={!!existingLlmProvider}
/>
<InputWrapper
label="API Key"
description="Paste your API key from {link} to access your models."
descriptionLink={{
text: "OpenRouter",
href: "https://openrouter.ai/settings/keys",
}}
>
<PasswordInputTypeInField name="api_key" label="" />
</InputWrapper>
<Separator />
<DisplayNameField />
<Separator />
@@ -207,10 +215,8 @@ export function OpenRouterForm({
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage={
"Fetch available models first, then you'll be able to select " +
"the models you want to make available in Onyx."
"No models available. Provide a valid base URL and key."
}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>

View File

@@ -1,5 +1,4 @@
import { Form, Formik } from "formik";
import { TextFormField, FileUploadFormField } from "@/components/Field";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import {
@@ -19,12 +18,53 @@ import {
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import Separator from "@/refresh-components/Separator";
import LLMFormLayout from "./components/FormLayout";
import { FormField } from "@/refresh-components/form/FormField";
import InputFile from "@/refresh-components/inputs/InputFile";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import Text from "@/refresh-components/texts/Text";
import { ModelAccessOptions } from "./components/ModelAccessOptions";
export const VERTEXAI_PROVIDER_NAME = "vertex_ai";
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
const VERTEXAI_DEFAULT_LOCATION = "global";
const VERTEXAI_REGION_OPTIONS = [
{ name: "global", value: "global" },
{ name: "us-central1", value: "us-central1" },
{ name: "us-east1", value: "us-east1" },
{ name: "us-east4", value: "us-east4" },
{ name: "us-east5", value: "us-east5" },
{ name: "us-south1", value: "us-south1" },
{ name: "us-west1", value: "us-west1" },
{ name: "northamerica-northeast1", value: "northamerica-northeast1" },
{ name: "southamerica-east1", value: "southamerica-east1" },
{ name: "europe-west4", value: "europe-west4" },
{ name: "europe-west9", value: "europe-west9" },
{ name: "europe-west2", value: "europe-west2" },
{ name: "europe-west3", value: "europe-west3" },
{ name: "europe-west1", value: "europe-west1" },
{ name: "europe-west6", value: "europe-west6" },
{ name: "europe-southwest1", value: "europe-southwest1" },
{ name: "europe-west8", value: "europe-west8" },
{ name: "europe-north1", value: "europe-north1" },
{ name: "europe-central2", value: "europe-central2" },
{ name: "asia-northeast1", value: "asia-northeast1" },
{ name: "australia-southeast1", value: "australia-southeast1" },
{ name: "asia-southeast1", value: "asia-southeast1" },
{ name: "asia-northeast3", value: "asia-northeast3" },
{ name: "asia-east1", value: "asia-east1" },
{ name: "asia-east2", value: "asia-east2" },
{ name: "asia-south1", value: "asia-south1" },
{ name: "me-central2", value: "me-central2" },
{ name: "me-central1", value: "me-central1" },
{ name: "me-west1", value: "me-west1" },
];
const VERTEXAI_REGION_NAME = "custom_config.vertex_location";
interface VertexAIFormValues extends BaseLLMFormValues {
custom_config: {
vertex_credentials: string;
@@ -34,11 +74,14 @@ interface VertexAIFormValues extends BaseLLMFormValues {
export function VertexAIForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName={VERTEXAI_DISPLAY_NAME}
providerDisplayName={existingLlmProvider?.name ?? "Gemini"}
providerInternalName={"vertex_ai"}
providerEndpoint={VERTEXAI_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -57,17 +100,26 @@ export function VertexAIForm({
existingLlmProvider,
wellKnownLLMProvider
);
const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL;
const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name ?? VERTEXAI_DEFAULT_MODEL
: undefined;
const initialValues: VertexAIFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL,
// Default to auto mode for new Vertex AI providers
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
is_auto_mode: isAutoMode,
custom_config: {
vertex_credentials:
(existingLlmProvider?.custom_config
@@ -129,34 +181,64 @@ export function VertexAIForm({
{(formikProps) => {
return (
<Form className={LLM_FORM_CLASS_NAME}>
<LLMFormLayout.Body>
<InputSelectField name={VERTEXAI_REGION_NAME}>
<Text as="p">Google Cloud Region Name</Text>
<InputSelect.Trigger placeholder="Select region" />
<InputSelect.Content>
{VERTEXAI_REGION_OPTIONS.map((option) => (
<InputSelect.Item
key={option.value}
value={option.value}
>
{option.name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelectField>
<FormField
name="custom_config.vertex_credentials"
state={
formikProps.errors.custom_config?.vertex_credentials
? "error"
: "idle"
}
>
<FormField.Label>API Key</FormField.Label>
<FormField.Control>
<InputFile
setValue={(value) =>
formikProps.setFieldValue(
"custom_config.vertex_credentials",
value
)
}
error={
!!formikProps.errors.custom_config
?.vertex_credentials
}
onBlur={formikProps.handleBlur}
showClearButton={true}
disabled={formikProps.isSubmitting}
accept="application/json"
placeholder="Vertex AI API KEY (JSON)"
/>
</FormField.Control>
</FormField>
</LLMFormLayout.Body>
<DisplayNameField disabled={!!existingLlmProvider} />
<FileUploadFormField
name="custom_config.vertex_credentials"
label="Credentials File"
subtext="Upload your Google Cloud service account JSON credentials file."
/>
<TextFormField
name="custom_config.vertex_location"
label="Location"
placeholder={VERTEXAI_DEFAULT_LOCATION}
subtext="The Google Cloud region for your Vertex AI models (e.g., global, us-east1, us-central1, europe-west1). See [Google's documentation](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#google_model_endpoint_locations) to find the appropriate region for your model."
optional
/>
<Separator />
<Separator noPadding />
<DisplayModels
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>
<AdvancedOptions formikProps={formikProps} />
<ModelAccessOptions />
<FormActionButtons
isTesting={isTesting}

View File

@@ -0,0 +1,86 @@
"use client";
import { Card } from "@/refresh-components/cards";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { DefaultModelSelectorProps } from "../../interfaces";
import { Section } from "@/layouts/general-layouts";
import { useMemo } from "react";
import * as InputLayouts from "@/layouts/input-layouts";
import { setDefaultLlmModel } from "@/lib/admin/llm/svc";
export function DefaultModelSelector({
existingLlmProviders,
defaultLlmModel,
onModelChange,
}: DefaultModelSelectorProps) {
// Flatten all models from all providers into a single list
const models = useMemo(() => {
return existingLlmProviders.flatMap((provider) =>
provider.model_configurations
.filter((model) => model.is_visible)
.map((model) => ({
model_display_name: model.display_name || model.name,
model_name: model.name,
provider_display_name: provider.name || provider.provider,
provider_id: provider.id,
}))
);
}, [existingLlmProviders]);
const modelChangeHandler = (provider_id: number, model_name: string) => {
setDefaultLlmModel(provider_id, model_name);
onModelChange(provider_id, model_name);
};
// Create a composite value string for the select (provider_id:model_name)
// Fall back to the first model if no default is set
const currentValue = useMemo(() => {
if (defaultLlmModel) {
return `${defaultLlmModel.provider_id}:${defaultLlmModel.model_name}`;
}
const firstModel = models[0];
if (firstModel) {
return `${firstModel.provider_id}:${firstModel.model_name}`;
}
return undefined;
}, [defaultLlmModel, models]);
const handleValueChange = (value: string) => {
const separatorIndex = value.indexOf(":");
const providerId = parseInt(value.slice(0, separatorIndex), 10);
const modelName = value.slice(separatorIndex + 1);
modelChangeHandler(providerId, modelName);
};
return (
<Card>
<Section
flexDirection="row"
justifyContent="between"
alignItems="center"
height="fit"
>
<InputLayouts.Horizontal
title="Default Model"
description="This model will be used by Onyx by default in your chats."
>
<InputSelect value={currentValue} onValueChange={handleValueChange}>
<InputSelect.Trigger placeholder="Select a model..." />
<InputSelect.Content>
{models.map((model) => (
<InputSelect.Item
key={`${model.provider_id}:${model.model_name}`}
value={`${model.provider_id}:${model.model_name}`}
description={model.provider_display_name}
>
{model.model_display_name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelect>
</InputLayouts.Horizontal>
</Section>
</Card>
);
}

View File

@@ -1,11 +1,25 @@
import { ModelConfiguration, SimpleKnownModel } from "../../interfaces";
import { FormikProps } from "formik";
import { BaseLLMFormValues } from "../formUtils";
import { useState } from "react";
import Checkbox from "@/refresh-components/inputs/Checkbox";
import Switch from "@/refresh-components/inputs/Switch";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import IconButton from "@/refresh-components/buttons/IconButton";
import { Card } from "@/refresh-components/cards";
import Separator from "@/refresh-components/Separator";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { SvgChevronDown, SvgRefreshCw } from "@opal/icons";
import { cn } from "@/lib/utils";
import { FieldLabel } from "@/components/Field";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "@/refresh-components/Collapsible";
import * as GeneralLayouts from "@/layouts/general-layouts";
import { SvgEmpty } from "@opal/icons";
interface AutoModeToggleProps {
isAutoMode: boolean;
@@ -14,51 +28,173 @@ interface AutoModeToggleProps {
function AutoModeToggle({ isAutoMode, onToggle }: AutoModeToggleProps) {
return (
<div className="flex items-center justify-between">
<div>
<Text as="p" mainUiAction className="block">
Auto Update
<GeneralLayouts.Section
flexDirection="row"
justifyContent="between"
alignItems="center"
>
<GeneralLayouts.Section gap={0.125} alignItems="start" width="fit">
<GeneralLayouts.Section
flexDirection="row"
gap={0.375}
alignItems="center"
width="fit"
>
<Text as="span" mainUiAction>
Auto Update
</Text>
<Text as="span" secondaryBody text03>
(Recommended)
</Text>
</GeneralLayouts.Section>
<Text as="p" secondaryBody text03>
Update the available models when new models are released.
</Text>
<Text as="p" secondaryBody text03 className="block">
Automatically update the available models when new models are
released. Recommended for most teams.
</Text>
</div>
<button
type="button"
role="switch"
aria-checked={isAutoMode}
className={cn(
"relative inline-flex h-6 w-11 shrink-0 cursor-pointer rounded-full",
"border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none",
isAutoMode ? "bg-action-link-05" : "bg-background-neutral-03"
)}
onClick={onToggle}
>
<span
className={cn(
"pointer-events-none inline-block h-5 w-5 transform rounded-full",
"bg-white shadow ring-0 transition duration-200 ease-in-out",
isAutoMode ? "translate-x-5" : "translate-x-0"
)}
/>
</button>
</div>
</GeneralLayouts.Section>
<Switch checked={isAutoMode} onCheckedChange={onToggle} />
</GeneralLayouts.Section>
);
}
function DisplayModelHeader({ alternativeText }: { alternativeText?: string }) {
interface DisplayModelHeaderProps {
alternativeText?: string;
onSelectAll?: () => void;
onRefresh?: () => void;
isAutoMode: boolean;
showSelectAll?: boolean;
}
function DisplayModelHeader({
alternativeText,
onSelectAll,
onRefresh,
isAutoMode,
showSelectAll = true,
}: DisplayModelHeaderProps) {
return (
<div>
<FieldLabel
label="Available Models"
subtext={
alternativeText ??
"Select which models to make available for this provider."
}
name="_available-models"
/>
</div>
<GeneralLayouts.Section
flexDirection="row"
justifyContent="between"
alignItems="center"
>
<GeneralLayouts.Section gap={0} alignItems="start" width="fit">
<Text as="p" mainContentBody>
Models
</Text>
<Text as="p" secondaryBody text03>
{alternativeText ??
"Select models to make available for this provider."}
</Text>
</GeneralLayouts.Section>
<GeneralLayouts.Section flexDirection="row" gap={0.25} width="fit">
{showSelectAll && (
<Button main tertiary onClick={onSelectAll} disabled={isAutoMode}>
Select All
</Button>
)}
<IconButton
icon={SvgRefreshCw}
main
internal
onClick={onRefresh}
tooltip="Refresh models"
/>
</GeneralLayouts.Section>
</GeneralLayouts.Section>
);
}
interface ModelRowProps {
modelName: string;
modelDisplayName?: string;
isSelected: boolean;
isDefault: boolean;
onCheckChange: (checked: boolean) => void;
onSetDefault?: () => void;
}
function ModelRow({
modelName,
modelDisplayName,
isSelected,
isDefault,
onCheckChange,
onSetDefault,
}: ModelRowProps) {
return (
<Card
variant="borderless"
flexDirection="row"
justifyContent="between"
alignItems="center"
className={cn("cursor-pointer group hover:bg-background-tint-01")}
padding={0.25}
onClick={() => onCheckChange(!isSelected)}
>
<GeneralLayouts.Section
flexDirection="row"
gap={0.75}
alignItems="center"
width="fit"
>
<Checkbox
checked={isSelected}
onCheckedChange={(checked) => onCheckChange(checked)}
onClick={(e) => e.stopPropagation()}
/>
<Text
as="span"
className={cn(
"select-none",
isSelected ? "text-action-link-04" : "text-text-03"
)}
>
{modelDisplayName ?? modelName}
</Text>
</GeneralLayouts.Section>
{isDefault ? (
<Text as="span" secondaryBody className="text-action-link-05">
Default Model
</Text>
) : (
onSetDefault && (
<Button
main
tertiary
className="opacity-0 group-hover:opacity-100 transition-opacity"
onClick={(e) => {
e.stopPropagation();
onSetDefault();
}}
>
<Text text03>Set as Default</Text>
</Button>
)
)}
</Card>
);
}
interface MoreModelsButtonProps {
isOpen: boolean;
}
function MoreModelsButton({ isOpen }: MoreModelsButtonProps) {
return (
<Button
main
internal
transient
leftIcon={SvgChevronDown}
className={cn(
"[&_svg]:transition-transform [&_svg]:duration-200",
isOpen && "[&_svg]:rotate-180"
)}
>
<Text as="span" text02>
More Models
</Text>
</Button>
);
}
@@ -67,54 +203,35 @@ export function DisplayModels<T extends BaseLLMFormValues>({
modelConfigurations,
noModelConfigurationsMessage,
isLoading,
recommendedDefaultModel,
shouldShowAutoUpdateToggle,
}: {
formikProps: FormikProps<T>;
modelConfigurations: ModelConfiguration[];
noModelConfigurationsMessage?: string;
isLoading?: boolean;
recommendedDefaultModel: SimpleKnownModel | null;
shouldShowAutoUpdateToggle: boolean;
}) {
const [moreModelsOpen, setMoreModelsOpen] = useState(false);
const isAutoMode = formikProps.values.is_auto_mode;
if (isLoading) {
return (
<div>
<DisplayModelHeader />
<div className="mt-2 flex items-center p-3 border border-border-01 rounded-lg bg-background-neutral-00">
<div className="h-5 w-5 animate-spin rounded-full border-2 border-border-03 border-t-action-link-05" />
</div>
</div>
);
}
const defaultModelName = formikProps.values.default_model_name;
const handleCheckboxChange = (modelName: string, checked: boolean) => {
// Read current values inside the handler to avoid stale closure issues
if (!checked && modelName === defaultModelName) {
return;
}
const currentSelected = formikProps.values.selected_model_names ?? [];
const currentDefault = formikProps.values.default_model_name;
if (checked) {
const newSelected = [...currentSelected, modelName];
formikProps.setFieldValue("selected_model_names", newSelected);
// If this is the first model, set it as default
if (currentSelected.length === 0) {
formikProps.setFieldValue("default_model_name", modelName);
}
} else {
const newSelected = currentSelected.filter((name) => name !== modelName);
formikProps.setFieldValue("selected_model_names", newSelected);
// If removing the default, set the first remaining model as default
if (currentDefault === modelName && newSelected.length > 0) {
formikProps.setFieldValue("default_model_name", newSelected[0]);
} else if (newSelected.length === 0) {
formikProps.setFieldValue("default_model_name", null);
}
}
};
const handleSetDefault = (modelName: string) => {
const onSetDefault = (modelName: string) => {
formikProps.setFieldValue("default_model_name", modelName);
};
@@ -124,19 +241,31 @@ export function DisplayModels<T extends BaseLLMFormValues>({
"selected_model_names",
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
);
formikProps.setFieldValue(
"default_model_name",
recommendedDefaultModel?.name ?? ""
);
};
const handleSelectAll = () => {
const allModelNames = modelConfigurations.map((m) => m.name);
formikProps.setFieldValue("selected_model_names", allModelNames);
};
const handleRefresh = () => {
// Trigger a refresh of models - this would need to be wired up to your fetch logic
// For now, this is a placeholder
};
const selectedModels = formikProps.values.selected_model_names ?? [];
const defaultModel = formikProps.values.default_model_name;
// Sort models: default first, then selected, then unselected
const primaryModels = modelConfigurations.filter((m) =>
selectedModels.includes(m.name)
);
const moreModels = modelConfigurations.filter(
(m) => !selectedModels.includes(m.name)
);
const sortedModelConfigurations = [...modelConfigurations].sort((a, b) => {
const aIsDefault = a.name === defaultModel;
const bIsDefault = b.name === defaultModel;
const aIsDefault = a.name === defaultModelName;
const bIsDefault = b.name === defaultModelName;
const aIsSelected = selectedModels.includes(a.name);
const bIsSelected = selectedModels.includes(b.name);
@@ -147,166 +276,136 @@ export function DisplayModels<T extends BaseLLMFormValues>({
return 0;
});
if (modelConfigurations.length === 0) {
return (
<div>
<DisplayModelHeader
alternativeText={noModelConfigurationsMessage ?? "No models found"}
/>
</div>
);
}
// Sort auto mode models: default model first
// For auto mode display
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
const sortedAutoModels = [...visibleModels].sort((a, b) => {
const aIsDefault = a.name === defaultModel;
const bIsDefault = b.name === defaultModel;
if (aIsDefault && !bIsDefault) return -1;
if (!aIsDefault && bIsDefault) return 1;
return 0;
});
const primaryAutoModels = visibleModels.filter((m) =>
selectedModels.includes(m.name)
);
const moreAutoModels = visibleModels.filter(
(m) => !selectedModels.includes(m.name)
);
return (
<div className="flex flex-col gap-3">
<DisplayModelHeader />
<div className="border border-border-01 rounded-lg p-3">
{shouldShowAutoUpdateToggle && (
<AutoModeToggle
isAutoMode={isAutoMode}
onToggle={handleToggleAutoMode}
/>
)}
{/* Model list section */}
<div
className={cn(
"flex flex-col gap-1",
shouldShowAutoUpdateToggle && "mt-3 pt-3 border-t border-border-01"
)}
>
<Card variant="borderless">
<DisplayModelHeader
onSelectAll={handleSelectAll}
onRefresh={handleRefresh}
showSelectAll={modelConfigurations.length > 0}
isAutoMode={isAutoMode}
/>
{modelConfigurations.length > 0 ? (
<Card variant="borderless" padding={0} gap={0}>
{isAutoMode && shouldShowAutoUpdateToggle ? (
// Auto mode: read-only display
<div className="flex flex-col gap-2">
{sortedAutoModels.map((model) => {
const isDefault = model.name === defaultModel;
<>
{primaryAutoModels.map((model) => {
const isDefault = model.name === defaultModelName;
return (
<div
<ModelRow
key={model.name}
className={cn(
"flex items-center justify-between gap-3 rounded-lg border p-1",
"bg-background-neutral-00",
isDefault ? "border-action-link-05" : "border-border-01"
)}
>
<div className="flex flex-1 items-center gap-2 px-2 py-1">
<div
className={cn(
"size-2 shrink-0 rounded-full",
isDefault
? "bg-action-link-05"
: "bg-background-neutral-03"
)}
/>
<div className="flex flex-col gap-0.5">
<Text mainUiAction text05>
{model.display_name || model.name}
</Text>
{model.display_name && (
<Text secondaryBody text03>
{model.name}
</Text>
)}
</div>
</div>
{isDefault && (
<div className="flex items-center justify-end pr-2">
<Text
secondaryBody
className="text-action-text-link-05"
>
Default
</Text>
</div>
)}
</div>
modelName={model.name}
modelDisplayName={model.display_name}
isSelected={true}
isDefault={isDefault}
onCheckChange={() => {}}
/>
);
})}
</div>
</>
) : (
// Manual mode: checkbox selection
<div
className={cn(
"flex flex-col gap-1",
"max-h-48 4xl:max-h-64",
"overflow-y-auto"
)}
>
{sortedModelConfigurations.map((modelConfiguration) => {
<>
{primaryModels.map((modelConfiguration) => {
const isSelected = selectedModels.includes(
modelConfiguration.name
);
const isDefault = defaultModel === modelConfiguration.name;
const isDefault = defaultModelName === modelConfiguration.name;
return (
<div
<ModelRow
key={modelConfiguration.name}
className="flex items-center justify-between py-1.5 px-2 rounded hover:bg-background-neutral-subtle"
>
<div
className="flex items-center gap-2 cursor-pointer"
onClick={() =>
handleCheckboxChange(
modelConfiguration.name,
!isSelected
)
}
>
<div
className="flex items-center"
onClick={(e) => e.stopPropagation()}
>
<Checkbox
checked={isSelected}
onCheckedChange={(checked) =>
handleCheckboxChange(
modelConfiguration.name,
checked
)
}
/>
</div>
<Text
as="p"
secondaryBody
className="select-none leading-none"
>
{modelConfiguration.name}
</Text>
</div>
<button
type="button"
disabled={!isSelected}
onClick={() => handleSetDefault(modelConfiguration.name)}
className={`text-xs px-2 py-0.5 rounded transition-all duration-200 ease-in-out ${
isSelected
? "opacity-100 translate-x-0"
: "opacity-0 translate-x-2 pointer-events-none"
} ${
isDefault
? "bg-action-link-05 text-text-inverse font-medium scale-100"
: "bg-background-neutral-02 text-text-03 hover:bg-background-neutral-03 scale-95 hover:scale-100"
}`}
>
{isDefault ? "Default" : "Set as default"}
</button>
</div>
modelName={modelConfiguration.name}
modelDisplayName={modelConfiguration.display_name}
isSelected={isSelected}
isDefault={isDefault}
onCheckChange={(checked) =>
handleCheckboxChange(modelConfiguration.name, checked)
}
onSetDefault={
onSetDefault
? () => onSetDefault(modelConfiguration.name)
: undefined
}
/>
);
})}
</div>
</>
)}
</div>
</div>
</div>
{moreModels.length > 0 && (
<Collapsible open={moreModelsOpen} onOpenChange={setMoreModelsOpen}>
<CollapsibleTrigger asChild>
<Card variant="borderless" padding={0}>
<MoreModelsButton isOpen={moreModelsOpen} />
</Card>
</CollapsibleTrigger>
<CollapsibleContent>
<GeneralLayouts.Section gap={0.25} alignItems="start">
{moreModels.map((modelConfiguration) => {
const isSelected = selectedModels.includes(
modelConfiguration.name
);
const isDefault =
defaultModelName === modelConfiguration.name;
return (
<ModelRow
key={modelConfiguration.name}
modelName={modelConfiguration.name}
modelDisplayName={modelConfiguration.display_name}
isSelected={isSelected}
isDefault={isDefault}
onCheckChange={(checked) =>
handleCheckboxChange(modelConfiguration.name, checked)
}
onSetDefault={() =>
onSetDefault(modelConfiguration.name)
}
/>
);
})}
</GeneralLayouts.Section>
</CollapsibleContent>
</Collapsible>
)}
{/* Auto update toggle */}
{shouldShowAutoUpdateToggle && (
<Card variant="borderless" padding={0.75} gap={0.75}>
<Separator noPadding />
<AutoModeToggle
isAutoMode={isAutoMode}
onToggle={handleToggleAutoMode}
/>
</Card>
)}
</Card>
) : (
<Card variant="tertiary">
<GeneralLayouts.Section
gap={0.5}
flexDirection="row"
alignItems="center"
justifyContent="start"
>
<SvgEmpty className="w-4 h-4 line-item-icon-muted" />
<Text text03 secondaryBody>
{noModelConfigurationsMessage ?? "No models found"}
</Text>
</GeneralLayouts.Section>
</Card>
)}
</Card>
);
}

View File

@@ -1,4 +1,7 @@
import { TextFormField } from "@/components/Field";
import Text from "@/refresh-components/texts/Text";
import * as GeneralLayouts from "@/layouts/general-layouts";
import InputWrapper from "./InputWrapper";
interface DisplayNameFieldProps {
disabled?: boolean;
@@ -6,12 +9,17 @@ interface DisplayNameFieldProps {
export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
return (
<TextFormField
name="name"
<InputWrapper
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
disabled={disabled}
/>
optional
description="Use to identify this provider in the app"
>
<TextFormField
name="name"
label=""
placeholder="Display Name"
disabled={disabled}
/>
</InputWrapper>
);
}

View File

@@ -2,13 +2,14 @@ import { LoadingAnimation } from "@/components/Loading";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { SvgTrash } from "@opal/icons";
import { LLMProviderView } from "../../interfaces";
import { DefaultModel, LLMProviderView } from "../../interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
interface FormActionButtonsProps {
isTesting: boolean;
testError: string;
existingLlmProvider?: LLMProviderView;
defaultModel?: DefaultModel;
mutate: (key: string) => void;
onClose: () => void;
isFormValid: boolean;
@@ -18,6 +19,7 @@ export function FormActionButtons({
isTesting,
testError,
existingLlmProvider,
defaultModel,
mutate,
onClose,
isFormValid,
@@ -39,16 +41,22 @@ export function FormActionButtons({
}
// If the deleted provider was the default, set the first remaining provider as default
if (existingLlmProvider.is_default_provider) {
if (defaultModel?.provider_id === existingLlmProvider.id) {
const remainingProvidersResponse = await fetch(LLM_PROVIDERS_ADMIN_URL);
if (remainingProvidersResponse.ok) {
const remainingProviders = await remainingProvidersResponse.json();
const remainingProvidersJson = await remainingProvidersResponse.json();
const remainingProviders = remainingProvidersJson.providers;
if (remainingProviders.length > 0) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
`${LLM_PROVIDERS_ADMIN_URL}/default`,
{
method: "POST",
body: JSON.stringify({
provider_id: remainingProviders[0].id,
model_name:
remainingProviders[0].model_configurations[0]?.name ?? "",
}),
}
);
if (!setDefaultResponse.ok) {

View File

@@ -0,0 +1,184 @@
"use client";
import React from "react";
import { Section } from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import Modal from "@/refresh-components/Modal";
import { LoadingAnimation } from "@/components/Loading";
import type { IconProps } from "@opal/types";
/**
* LLMFormLayout - A compound component for LLM provider configuration forms
*
* Provides a complete modal form structure with:
* - Modal: The modal wrapper with header (icon, title, close button)
* - Body: Form content area
* - Footer: Cancel and Connect/Update buttons
*
* @example
* ```tsx
* <LLMFormLayout.Modal
* icon={OpenAIIcon}
* displayName="OpenAI"
* providerName="openai"
* onClose={handleClose}
* isEditing={false}
* >
* <Formik ...>
* {(formikProps) => (
* <Form>
* <LLMFormLayout.Body>
* <InputField name="api_key" label="API Key" />
* </LLMFormLayout.Body>
* <LLMFormLayout.Footer
* onCancel={handleClose}
* isSubmitting={isSubmitting}
* />
* </Form>
* )}
* </Formik>
* </LLMFormLayout.Modal>
* ```
*/
// ============================================================================
// Modal (Root)
// ============================================================================
interface LLMFormModalProps {
children: React.ReactNode;
/** Icon component for the provider */
icon: React.FunctionComponent<IconProps>;
/** Display name shown in the modal title */
displayName: string;
/** Whether editing an existing provider (changes title) */
isEditing?: boolean;
/** Custom name of the existing provider (shown in title when editing) */
existingName?: string;
/** Handler for closing the modal */
onClose: () => void;
}
function LLMFormModal({
children,
icon: Icon,
displayName,
isEditing = false,
existingName,
onClose,
}: LLMFormModalProps) {
const title = isEditing
? `Configure ${existingName ? `"${existingName}"` : displayName}`
: `Setup up ${displayName}`;
const description = isEditing
? ""
: `Connect to ${displayName} to set up your ${displayName} models.`;
return (
<Modal open onOpenChange={onClose}>
<Modal.Content width="md-sm">
<Modal.Header
icon={Icon}
title={title}
description={description}
onClose={onClose}
/>
<Modal.Body alignItems="stretch">{children}</Modal.Body>
</Modal.Content>
</Modal>
);
}
// ============================================================================
// Body
// ============================================================================
interface LLMFormBodyProps {
children: React.ReactNode;
}
function LLMFormBody({ children }: LLMFormBodyProps) {
return <>{children}</>;
}
// ============================================================================
// Footer
// ============================================================================
interface LLMFormFooterProps {
/** Handler for cancel button */
onCancel: () => void;
/** Label for the cancel button. Default: "Cancel" */
cancelLabel?: string;
/** Label for the submit button. Default: "Connect" */
submitLabel?: string;
/** Text to show while submitting. Default: "Testing" */
submittingLabel?: string;
/** Whether the form is currently submitting */
isSubmitting?: boolean;
/** Whether the submit button should be disabled */
isSubmitDisabled?: boolean;
/** Optional left-side content (e.g., delete button) */
leftChildren?: React.ReactNode;
/** Error message to display */
error?: string;
}
function LLMFormFooter({
onCancel,
cancelLabel = "Cancel",
submitLabel = "Connect",
submittingLabel = "Testing",
isSubmitting = false,
isSubmitDisabled = false,
leftChildren,
error,
}: LLMFormFooterProps) {
return (
<Section alignItems="stretch" gap={0.5}>
{error && (
<Text as="p" className="text-error">
{error}
</Text>
)}
<Section flexDirection="row" justifyContent="between" gap={0.5}>
<Section
width="fit"
flexDirection="row"
justifyContent="start"
gap={0.5}
>
{leftChildren}
</Section>
<Section width="fit" flexDirection="row" justifyContent="end" gap={0.5}>
<Button secondary onClick={onCancel} disabled={isSubmitting}>
{cancelLabel}
</Button>
<Button type="submit" disabled={isSubmitting || isSubmitDisabled}>
{isSubmitting ? (
<Text as="span" inverted>
<LoadingAnimation text={submittingLabel} />
</Text>
) : (
submitLabel
)}
</Button>
</Section>
</Section>
</Section>
);
}
// ============================================================================
// Export
// ============================================================================
const LLMFormLayout = {
Modal: LLMFormModal,
Body: LLMFormBody,
Footer: LLMFormFooter,
};
export default LLMFormLayout;

View File

@@ -8,13 +8,13 @@ import {
WellKnownLLMProviderDescriptor,
} from "../../interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import Modal from "@/refresh-components/Modal";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { SvgSettings } from "@opal/icons";
import { Badge } from "@/components/ui/badge";
import { cn } from "@/lib/utils";
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
import { SvgArrowExchange, SvgSettings } from "@opal/icons";
import { Card } from "@/refresh-components/cards";
import * as GeneralLayouts from "@/layouts/general-layouts";
import { ProviderIcon } from "../../ProviderIcon";
import IconButton from "@/refresh-components/buttons/IconButton";
import LineItem from "@/refresh-components/buttons/LineItem";
import LLMFormLayout from "./FormLayout";
export interface ProviderFormContext {
onClose: () => void;
@@ -31,21 +31,19 @@ export interface ProviderFormContext {
interface ProviderFormEntrypointWrapperProps {
children: (context: ProviderFormContext) => ReactNode;
providerName: string;
providerDisplayName?: string;
providerInternalName?: string;
providerEndpoint?: string;
existingLlmProvider?: LLMProviderView;
/** When true, renders a simple button instead of a card-based UI */
buttonMode?: boolean;
/** Custom button text for buttonMode (defaults to "Add {providerName}") */
buttonText?: string;
}
export function ProviderFormEntrypointWrapper({
children,
providerName,
providerDisplayName,
providerInternalName,
providerEndpoint,
existingLlmProvider,
buttonMode,
buttonText,
}: ProviderFormEntrypointWrapperProps) {
const [formIsVisible, setFormIsVisible] = useState(false);
@@ -67,31 +65,6 @@ export function ProviderFormEntrypointWrapper({
const onClose = () => setFormIsVisible(false);
async function handleSetAsDefault(): Promise<void> {
if (!existingLlmProvider) return;
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
{
method: "POST",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
setPopup({
type: "error",
message: `Failed to set provider as default: ${errorMsg}`,
});
return;
}
await mutate(LLM_PROVIDERS_ADMIN_URL);
setPopup({
type: "success",
message: "Provider set as default successfully!",
});
}
const context: ProviderFormContext = {
onClose,
mutate,
@@ -104,113 +77,51 @@ export function ProviderFormEntrypointWrapper({
wellKnownLLMProvider,
};
// Button mode: simple button that opens a modal
if (buttonMode && !existingLlmProvider) {
return (
<>
{popup}
<Button action onClick={() => setFormIsVisible(true)}>
{buttonText ?? `Add ${providerName}`}
</Button>
const displayName = providerDisplayName ?? providerName;
const internalName = providerInternalName ?? providerName;
{formIsVisible && (
<Modal open onOpenChange={onClose}>
<Modal.Content>
<Modal.Header
icon={SvgSettings}
title={`Setup ${providerName}`}
onClose={onClose}
/>
<Modal.Body>{children(context)}</Modal.Body>
</Modal.Content>
</Modal>
)}
</>
);
}
// Card mode: card-based UI with modal
return (
<div>
<>
{popup}
<div className="border p-3 bg-background-neutral-01 rounded-16 w-96 flex shadow-md">
<Card padding={0}>
{existingLlmProvider ? (
<>
<div className="my-auto">
<Text
as="p"
headingH3
text04
className="text-ellipsis overflow-hidden max-w-32"
>
{existingLlmProvider.name}
</Text>
<Text as="p" secondaryBody text03 className="italic">
({providerName})
</Text>
{!existingLlmProvider.is_default_provider && (
<Text
as="p"
className={cn("text-action-link-05", "cursor-pointer")}
onClick={handleSetAsDefault}
>
Set as default
</Text>
)}
</div>
{existingLlmProvider && (
<div className="my-auto ml-3">
{existingLlmProvider.is_default_provider ? (
<Badge variant="agent">Default</Badge>
) : (
<Badge variant="success">Enabled</Badge>
)}
</div>
)}
<div className="ml-auto my-auto">
<Button
action={!existingLlmProvider}
secondary={!!existingLlmProvider}
<GeneralLayouts.CardItemLayout
icon={() => <ProviderIcon provider={internalName} size={24} />}
title={displayName}
description={providerName}
rightChildren={
<IconButton
icon={SvgSettings}
internal
onClick={() => setFormIsVisible(true)}
>
Edit
</Button>
</div>
</>
/>
}
/>
) : (
<>
<div className="my-auto">
<Text as="p" headingH3>
{providerName}
</Text>
</div>
<div className="ml-auto my-auto">
<Button action onClick={() => setFormIsVisible(true)}>
Set up
</Button>
</div>
</>
<GeneralLayouts.CardItemLayout
icon={() => <ProviderIcon provider={internalName} size={24} />}
title={displayName}
description={providerName}
rightChildren={
<LineItem
children="Connect"
icon={SvgArrowExchange}
onClick={() => setFormIsVisible(true)}
/>
}
/>
)}
</div>
</Card>
{formIsVisible && (
<Modal open onOpenChange={onClose}>
<Modal.Content>
<Modal.Header
icon={SvgSettings}
title={`${existingLlmProvider ? "Configure" : "Setup"} ${
existingLlmProvider?.name
? `"${existingLlmProvider.name}"`
: providerName
}`}
onClose={onClose}
/>
<Modal.Body>{children(context)}</Modal.Body>
</Modal.Content>
</Modal>
<LLMFormLayout.Modal
icon={() => <ProviderIcon provider={internalName} size={24} />}
displayName={displayName}
onClose={onClose}
>
{children(context)}
</LLMFormLayout.Modal>
)}
</div>
</>
);
}

View File

@@ -0,0 +1,75 @@
import * as GeneralLayouts from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
interface DescriptionLink {
text: string;
href: string;
}
interface InputWrapperProps {
label: string;
description?: string;
descriptionLink?: DescriptionLink;
children: React.ReactNode;
optional?: boolean;
}
export default function InputWrapper({
label,
optional,
description,
descriptionLink,
children,
}: InputWrapperProps) {
const renderDescription = () => {
if (!description) return null;
if (descriptionLink && description.includes("{link}")) {
const [before, after] = description.split("{link}");
return (
<Text as="p" secondaryBody text03>
{before}
<a
href={descriptionLink.href}
target="_blank"
rel="noopener noreferrer"
className="underline"
>
{descriptionLink.text}
</a>
{after}
</Text>
);
}
return (
<Text as="p" secondaryBody text03>
{description}
</Text>
);
};
return (
<GeneralLayouts.Section
flexDirection="column"
alignItems="start"
gap={0.25}
>
<GeneralLayouts.Section
flexDirection="row"
gap={0.25}
alignItems="center"
justifyContent="start"
>
<Text as="p">{label}</Text>
{optional && (
<Text as="p" text03>
(Optional)
</Text>
)}
</GeneralLayouts.Section>
{children}
{renderDescription()}
</GeneralLayouts.Section>
);
}

View File

@@ -0,0 +1,64 @@
"use client";
import * as InputLayouts from "@/layouts/input-layouts";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { useState } from "react";
import { SvgOrganization } from "@opal/icons";
import * as GeneralLayouts from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
import { Card } from "@/refresh-components/cards";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
enum ModelAccessOption {
ALL = "all", // All users and agents
NAMED = "named", // Only named users and agents
}
export function ModelAccessOptions() {
const [modelAccessOption, setModelAccessOption] = useState<ModelAccessOption>(
ModelAccessOption.ALL
);
return (
<>
<InputLayouts.Horizontal
title="Model Access"
description="Who can access this provider."
>
<InputSelectField
name="model_access_options"
defaultValue={modelAccessOption}
onValueChange={(value) =>
setModelAccessOption(value as ModelAccessOption)
}
>
<InputSelect.Trigger />
<InputSelect.Content>
<InputSelect.Item value={ModelAccessOption.ALL}>
All users and agents
</InputSelect.Item>
<InputSelect.Item value={ModelAccessOption.NAMED}>
Named users and agents
</InputSelect.Item>
</InputSelect.Content>
</InputSelectField>
</InputLayouts.Horizontal>
{modelAccessOption === ModelAccessOption.NAMED && (
<NamedModelAccessOptions />
)}
</>
);
}
function NamedModelAccessOptions() {
return (
<Card>
<InputTypeIn
variant="primary"
placeholder="Add users, groups, accounts, and agents"
/>
</Card>
);
}

View File

@@ -3,30 +3,26 @@ import {
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "../interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "../constants";
import { LLM_PROVIDERS_ADMIN_URL, LLM_ADMIN_URL } from "../constants";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
import { setDefaultLlmModel } from "@/lib/admin/llm/svc";
// Common class names for the Form component across all LLM provider forms
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
export const buildDefaultInitialValues = (
existingLlmProvider?: LLMProviderView,
modelConfigurations?: ModelConfiguration[]
modelConfigurations?: ModelConfiguration[],
defaultModelName?: string
) => {
const defaultModelName =
existingLlmProvider?.default_model_name ??
modelConfigurations?.[0]?.name ??
"";
// Auto mode must be explicitly enabled by the user
// Default to false for new providers, preserve existing value when editing
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;
return {
name: existingLlmProvider?.name || "",
default_model_name: defaultModelName,
is_public: existingLlmProvider?.is_public ?? true,
is_auto_mode: isAutoMode,
groups: existingLlmProvider?.groups ?? [],
@@ -38,18 +34,19 @@ export const buildDefaultInitialValues = (
: modelConfigurations
?.filter((modelConfiguration) => modelConfiguration.is_visible)
.map((modelConfiguration) => modelConfiguration.name) ?? [],
default_model_name: defaultModelName,
};
};
export const buildDefaultValidationSchema = () => {
return Yup.object({
name: Yup.string().required("Display Name is required"),
default_model_name: Yup.string().required("Model name is required"),
is_public: Yup.boolean().required(),
is_auto_mode: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
personas: Yup.array().of(Yup.number()),
selected_model_names: Yup.array().of(Yup.string()),
default_model_name: Yup.string().optional(),
});
};
@@ -81,13 +78,13 @@ export interface BaseLLMFormValues {
name: string;
api_key?: string;
api_base?: string;
default_model_name?: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
selected_model_names: string[];
custom_config?: Record<string, string>;
default_model_name?: string;
}
export interface SubmitLLMProviderParams<
@@ -110,8 +107,7 @@ export interface SubmitLLMProviderParams<
export const filterModelConfigurations = (
currentModelConfigurations: ModelConfiguration[],
visibleModels: string[],
defaultModelName?: string
visibleModels: string[]
): ModelConfiguration[] => {
return currentModelConfigurations
.map(
@@ -123,11 +119,15 @@ export const filterModelConfigurations = (
display_name: modelConfiguration.display_name,
})
)
.filter(
(modelConfiguration) =>
modelConfiguration.name === defaultModelName ||
modelConfiguration.is_visible
);
.filter((modelConfiguration) => modelConfiguration.is_visible);
};
const getFirstVisibleModelConfiguration = (
modelConfigurations: ModelConfiguration[]
): ModelConfiguration | undefined => {
return modelConfigurations.find(
(modelConfiguration) => modelConfiguration.is_visible
);
};
// Helper to get model configurations for auto mode
@@ -169,27 +169,14 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
// In auto mode, use recommended models from descriptor
// In manual mode, use user's selection
let filteredModelConfigurations: ModelConfiguration[];
let finalDefaultModelName = rest.default_model_name;
if (values.is_auto_mode) {
filteredModelConfigurations =
getAutoModeModelConfigurations(modelConfigurations);
// In auto mode, use the first recommended model as default if current default isn't in the list
const visibleModelNames = new Set(
filteredModelConfigurations.map((m) => m.name)
);
if (
finalDefaultModelName &&
!visibleModelNames.has(finalDefaultModelName)
) {
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
}
} else {
filteredModelConfigurations = filterModelConfigurations(
modelConfigurations,
visibleModels,
rest.default_model_name as string | undefined
visibleModels
);
}
@@ -200,7 +187,6 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
const finalValues = {
...rest,
default_model_name: finalDefaultModelName,
api_key,
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
custom_config_changed: customConfigChanged,
@@ -211,6 +197,13 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
if (!isEqual(finalValues, initialValues)) {
setIsTesting(true);
const testModel =
finalValues.default_model_name ??
(filteredModelConfigurations.length > 0
? getFirstVisibleModelConfiguration(filteredModelConfigurations)?.name
: modelConfigurations[0]?.name) ??
"";
const response = await fetch("/api/admin/llm/test", {
method: "POST",
headers: {
@@ -218,6 +211,7 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
},
body: JSON.stringify({
provider: providerName,
model: testModel,
...finalValues,
}),
});
@@ -263,13 +257,11 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
return;
}
if (shouldMarkAsDefault) {
if (shouldMarkAsDefault && finalValues.default_model_name) {
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{
method: "POST",
}
const setDefaultResponse = await setDefaultLlmModel(
newLlmProvider.id,
finalValues.default_model_name
);
if (!setDefaultResponse.ok) {
const errorMsg = (await setDefaultResponse.json()).detail;

View File

@@ -47,20 +47,16 @@ export interface LLMProvider {
api_base: string | null;
api_version: string | null;
custom_config: { [key: string]: string } | null;
default_model_name: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
deployment_name: string | null;
default_vision_model: string | null;
is_default_vision_provider: boolean | null;
model_configurations: ModelConfiguration[];
}
export interface LLMProviderView extends LLMProvider {
id: number;
is_default_provider: boolean | null;
}
export interface VisionProvider extends LLMProviderView {
@@ -68,19 +64,26 @@ export interface VisionProvider extends LLMProviderView {
}
export interface LLMProviderDescriptor {
id: number;
name: string;
provider: string;
provider_display_name?: string;
default_model_name: string;
is_default_provider: boolean | null;
is_default_vision_provider?: boolean | null;
default_vision_model?: string | null;
is_public?: boolean;
groups?: number[];
personas?: number[];
model_configurations: ModelConfiguration[];
}
export interface DefaultModel {
provider_id: number;
model_name: string;
}
export interface LLMProviderResponse<
T extends LLMProviderView | LLMProviderDescriptor,
> {
providers: T[];
default_text: DefaultModel | null;
default_vision: DefaultModel | null;
}
export interface OllamaModelResponse {
name: string;
display_name: string;
@@ -104,9 +107,16 @@ export interface BedrockModelResponse {
export interface LLMProviderFormProps {
existingLlmProvider?: LLMProviderView;
defaultLlmModel?: DefaultModel;
shouldMarkAsDefault?: boolean;
}
export interface DefaultModelSelectorProps {
existingLlmProviders: LLMProviderView[];
defaultLlmModel: DefaultModel | null;
onModelChange: (provider_id: number, model_name: string) => void;
}
// Param types for model fetching functions - use snake_case to match API structure
export interface BedrockFetchParams {
aws_region_name: string;

View File

@@ -3,12 +3,15 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { LLMConfiguration } from "./LLMConfiguration";
import { SvgCpu } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";
export default function Page() {
return (
<>
<AdminPageTitle title="LLM Setup" icon={SvgCpu} />
<LLMConfiguration />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={SvgCpu} title="LLM Models" separator />
<SettingsLayouts.Body>
<LLMConfiguration />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -25,7 +25,7 @@ import {
OllamaFetchParams,
OpenRouterFetchParams,
} from "./interfaces";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
import { SvgAws, SvgOpenrouter, SvgServer } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -69,6 +69,9 @@ export const getProviderIcon = (
bedrock_converse: SvgAws,
openrouter: SvgOpenrouter,
vertex_ai: GeminiIcon,
// Custom providers
litellm: SvgServer,
};
const lowerProviderName = providerName.toLowerCase();

View File

@@ -16,6 +16,7 @@ import LLMSelector from "@/components/llm/LLMSelector";
import { useVisionProviders } from "./hooks/useVisionProviders";
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
import { SvgAlertTriangle } from "@opal/icons";
import { useLlmManager } from "@/lib/hooks";
export function Checkbox({
label,

View File

@@ -24,17 +24,15 @@ export function useVisionProviders(setPopup: SetPopup) {
setError(null);
try {
const data = await fetchVisionProviders();
setVisionProviders(data);
setVisionProviders(data.providers);
// Find the default vision provider and set it
const defaultProvider = data.find(
(provider) => provider.is_default_vision_provider
const defaultProvider = data.providers.find(
(provider) => provider.id === data.default_vision?.provider_id
);
if (defaultProvider) {
const modelToUse =
defaultProvider.default_vision_model ||
defaultProvider.default_model_name;
const modelToUse = data.default_vision?.model_name;
if (modelToUse && defaultProvider.vision_models.includes(modelToUse)) {
setVisionLLM(

View File

@@ -1,5 +1,8 @@
import { useMemo, useState, useCallback } from "react";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import {
BuildLlmSelection,
getBuildLlmSelection,
@@ -16,7 +19,8 @@ import {
* 2. Smart default - via getDefaultLlmSelection()
*/
export function useBuildLlmSelection(
llmProviders: LLMProviderDescriptor[] | undefined
llmProviders: LLMProviderDescriptor[] | undefined,
defaultLlmModel?: DefaultModel
) {
const [selection, setSelectionState] = useState<BuildLlmSelection | null>(
() => getBuildLlmSelection()
@@ -42,7 +46,19 @@ export function useBuildLlmSelection(
}
// Fall back to smart default
return getDefaultLlmSelection(llmProviders);
return getDefaultLlmSelection(
llmProviders?.map((p) => ({
name: p.name,
provider: p.provider,
default_model_name: (() => {
if (p.id === defaultLlmModel?.provider_id) {
return defaultLlmModel?.model_name ?? "";
}
return p.model_configurations[0]?.name ?? "";
})(),
is_default_provider: p.id === defaultLlmModel?.provider_id,
}))
);
}, [selection, llmProviders, isSelectionValid]);
// Update selection and persist to cookie

View File

@@ -59,6 +59,7 @@ export function BuildOnboardingProvider({
<BuildOnboardingModal
mode={controller.mode}
llmProviders={controller.llmProviders}
defaultLlm={controller.defaultText}
initialValues={controller.initialValues}
isAdmin={controller.isAdmin}
hasUserInfo={controller.hasUserInfo}

View File

@@ -18,7 +18,10 @@ import {
getBuildLlmSelection,
getDefaultLlmSelection,
} from "@/app/craft/onboarding/constants";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
import {
buildInitialValues,
@@ -36,12 +39,25 @@ import OnboardingLlmSetup, {
* Used when user completes onboarding without going through LLM setup step.
*/
function autoSelectBestLlm(
llmProviders: LLMProviderDescriptor[] | undefined
llmProviders: LLMProviderDescriptor[] | undefined,
defaultLlmModel?: DefaultModel
): void {
// Don't override if user already has a selection
if (getBuildLlmSelection()) return;
const selection = getDefaultLlmSelection(llmProviders);
const selection = getDefaultLlmSelection(
llmProviders?.map((p) => ({
name: p.name,
provider: p.provider,
default_model_name: (() => {
if (p.id === defaultLlmModel?.provider_id) {
return defaultLlmModel?.model_name ?? "";
}
return p.model_configurations[0]?.name ?? "";
})(),
is_default_provider: p.id === defaultLlmModel?.provider_id,
}))
);
if (selection) {
setBuildLlmSelection(selection);
}
@@ -57,6 +73,7 @@ interface InitialValues {
interface BuildOnboardingModalProps {
mode: OnboardingModalMode;
llmProviders?: LLMProviderDescriptor[];
defaultLlm?: DefaultModel;
initialValues: InitialValues;
isAdmin: boolean;
hasUserInfo: boolean;
@@ -103,6 +120,7 @@ function getStepsForMode(
export default function BuildOnboardingModal({
mode,
llmProviders,
defaultLlm,
initialValues,
isAdmin,
hasUserInfo,
@@ -271,8 +289,12 @@ export default function BuildOnboardingModal({
if (!llmProviders || llmProviders.length === 0) {
const newProvider = await response.json();
if (newProvider?.id) {
await fetch(`${LLM_PROVIDERS_ADMIN_URL}/${newProvider.id}/default`, {
await fetch(`${LLM_PROVIDERS_ADMIN_URL}/default`, {
method: "POST",
body: JSON.stringify({
provider_id: newProvider.id,
model_name: selectedModel,
}),
});
}
}
@@ -322,7 +344,7 @@ export default function BuildOnboardingModal({
// Auto-select best available LLM if user didn't go through LLM setup
// (e.g., non-admin users or when all providers already configured)
autoSelectBestLlm(llmProviders);
autoSelectBestLlm(llmProviders, defaultLlm);
// Validate workArea is provided before submission
if (!workArea) {

View File

@@ -46,6 +46,7 @@ export function useOnboardingModal(): OnboardingModalController {
const { user, isAdmin, refreshUser } = useUser();
const {
llmProviders,
defaultText,
isLoading: isLoadingLlm,
refetch: refetchLlmProviders,
} = useLLMProviders();
@@ -168,6 +169,7 @@ export function useOnboardingModal(): OnboardingModalController {
openLlmSetup,
close,
llmProviders,
defaultText: defaultText ?? undefined,
initialValues,
completeUserInfo,
completeLlmSetup,

View File

@@ -36,6 +36,9 @@ export interface OnboardingModalController {
llmProviders:
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| undefined;
defaultText:
| import("@/app/admin/configuration/llm/interfaces").DefaultModel
| undefined;
initialValues: {
firstName: string;
lastName: string;
@@ -54,7 +57,9 @@ export interface OnboardingModalController {
completeUserInfo: (info: BuildUserInfo) => Promise<void>;
completeLlmSetup: () => Promise<void>;
refetchLlmProviders: () => Promise<
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| import("@/app/admin/configuration/llm/interfaces").LLMProviderResponse<
import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor
>
| undefined
>;
}

View File

@@ -78,7 +78,7 @@ interface SelectedConnectorState {
*/
export default function BuildConfigPage() {
const { isAdmin, isCurator } = useUser();
const { llmProviders } = useLLMProviders();
const { llmProviders, defaultText } = useLLMProviders();
const { openPersonaEditor, openLlmSetup } = useOnboarding();
const [selectedConnector, setSelectedConnector] =
useState<SelectedConnectorState | null>(null);
@@ -109,7 +109,7 @@ export default function BuildConfigPage() {
// Build mode LLM selection (cookie-based)
const { selection: llmSelection, updateSelection: updateLlmSelection } =
useBuildLlmSelection(llmProviders);
useBuildLlmSelection(llmProviders, defaultText ?? undefined);
// Read demo data from cookie (single source of truth)
const [demoDataEnabled, setDemoDataEnabledLocal] = useState(() =>

View File

@@ -1,4 +1,5 @@
import {
LLMProviderResponse,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
@@ -36,9 +37,14 @@ export async function checkLlmProvider(user: User | null) {
const [providerResponse, optionsResponse, defaultCheckResponse] =
await Promise.all(tasks);
let providers: LLMProviderView[] = [];
let providers: LLMProviderResponse<LLMProviderView> = {
providers: [],
default_text: null,
default_vision: null,
};
if (providerResponse?.ok) {
providers = await providerResponse.json();
providers =
(await providerResponse.json()) as LLMProviderResponse<LLMProviderView>;
}
let options: WellKnownLLMProviderDescriptor[] = [];
@@ -52,5 +58,5 @@ export async function checkLlmProvider(user: User | null) {
setDefaultLLMProviderTestComplete();
}
return { providers, options, defaultCheckSuccessful };
return { providers: providers.providers, options, defaultCheckSuccessful };
}

View File

@@ -2,7 +2,10 @@
import { useMemo } from "react";
import { parseLlmDescriptor, structureValue } from "@/lib/llm/utils";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { createIcon } from "@/components/icons/icons";
@@ -139,17 +142,6 @@ export default function LLMSelector({
});
}, [llmOptions]);
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);
const defaultModelName = defaultProvider?.default_model_name;
const defaultModelConfig = defaultProvider?.model_configurations.find(
(m) => m.name === defaultModelName
);
const defaultModelDisplayName = defaultModelConfig
? defaultModelConfig.display_name || defaultModelConfig.name
: defaultModelName || null;
const defaultLabel = userSettings ? "System Default" : "User Default";
// Determine if we should show grouped view (only if we have multiple vendors)
@@ -164,16 +156,7 @@ export default function LLMSelector({
<InputSelect.Content>
{!excludePublicProviders && (
<InputSelect.Item
value="default"
description={
userSettings && defaultModelDisplayName
? `(${defaultModelDisplayName})`
: undefined
}
>
{defaultLabel}
</InputSelect.Item>
<InputSelect.Item value="default">{defaultLabel}</InputSelect.Item>
)}
{showGrouped
? groupedOptions.map((group) => (

View File

@@ -974,7 +974,8 @@ export default function useChatController({
const [_, llmModel] = getFinalLLM(
llmManager.llmProviders || [],
liveAssistant || null,
llmManager.currentLlm
llmManager.currentLlm,
llmManager.defaultLlmModel
);
const llmAcceptsImages = modelSupportsImageInput(
llmManager.llmProviders || [],

View File

@@ -0,0 +1,16 @@
import { LLM_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
export function setDefaultLlmModel(providerId: number, modelName: string) {
const response = fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: providerId,
model_name: modelName,
}),
});
return response;
}

View File

@@ -32,7 +32,10 @@ import {
MinimalPersonaSnapshot,
PersonaLabel,
} from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/utils";
import { getSourceMetadataForSources } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
@@ -471,6 +474,8 @@ export interface LlmDescriptor {
export interface LlmManager {
currentLlm: LlmDescriptor;
updateCurrentLlm: (newOverride: LlmDescriptor) => void;
defaultLlmModel?: DefaultModel;
updateDefaultLlmModel: (newDefaultLlmModel: DefaultModel) => void;
temperature: number;
updateTemperature: (temperature: number) => void;
updateModelOverrideBasedOnChatSession: (chatSession?: ChatSession) => void;
@@ -525,16 +530,20 @@ providing appropriate defaults for new conversations based on the available tool
*/
function getDefaultLlmDescriptor(
llmProviders: LLMProviderDescriptor[]
llmProviders: LLMProviderDescriptor[],
defaultLlmModel?: DefaultModel
): LlmDescriptor | null {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
(provider) => provider.id === defaultLlmModel?.provider_id
);
if (defaultProvider) {
return {
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
modelName:
defaultLlmModel?.model_name ??
defaultProvider.model_configurations[0]?.name ??
"",
};
}
const firstLlmProvider = llmProviders.find(
@@ -544,7 +553,7 @@ function getDefaultLlmDescriptor(
return {
name: firstLlmProvider.name,
provider: firstLlmProvider.provider,
modelName: firstLlmProvider.default_model_name,
modelName: firstLlmProvider.model_configurations[0]?.name ?? "",
};
}
return null;
@@ -558,8 +567,13 @@ export function useLlmManager(
// Get all user-accessible providers via SWR (general providers - no persona filter)
// This includes public + all restricted providers user can access via groups
const { llmProviders: allUserProviders, isLoading: isLoadingAllProviders } =
useLLMProviders();
const {
llmProviders: allUserProviders,
defaultText,
isLoading: isLoadingAllProviders,
} = useLLMProviders();
const defaultTextLlmModel = defaultText ?? undefined;
// Fetch persona-specific providers to enforce RBAC restrictions per assistant
// Only fetch if we have an assistant selected
const personaId =
@@ -581,6 +595,13 @@ export function useLlmManager(
modelName: "",
});
const [defaultLlmModel, setDefaultLlmModel] = useState<
DefaultModel | undefined
>(defaultTextLlmModel);
const updateDefaultLlmModel = (newDefaultLlmModel: DefaultModel) => {
setDefaultLlmModel(newDefaultLlmModel);
};
// Track the previous assistant ID to detect when it changes
const prevAssistantIdRef = useRef<number | undefined>(undefined);
@@ -629,7 +650,10 @@ export function useLlmManager(
} else if (user?.preferences?.default_model) {
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
} else {
const defaultLlm = getDefaultLlmDescriptor(llmProviders);
const defaultLlm = getDefaultLlmDescriptor(
llmProviders,
defaultTextLlmModel
);
if (defaultLlm) {
setCurrentLlm(defaultLlm);
}
@@ -679,7 +703,7 @@ export function useLlmManager(
// Model not found in available providers - fall back to default model
return (
getDefaultLlmDescriptor(llmProviders) ?? {
getDefaultLlmDescriptor(llmProviders, defaultTextLlmModel) ?? {
name: "",
provider: "",
modelName: "",
@@ -787,6 +811,8 @@ export function useLlmManager(
updateModelOverrideBasedOnChatSession,
currentLlm,
updateCurrentLlm,
defaultLlmModel,
updateDefaultLlmModel,
temperature,
updateTemperature,
imageFilesPresent,

View File

@@ -1,5 +1,8 @@
import useSWR from "swr";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
LLMProviderDescriptor,
LLMProviderResponse,
} from "@/app/admin/configuration/llm/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
export function useLLMProviders(personaId?: number) {
@@ -12,17 +15,17 @@ export function useLLMProviders(personaId?: number) {
? `/api/llm/persona/${personaId}/providers`
: "/api/llm/provider";
const { data, error, mutate } = useSWR<LLMProviderDescriptor[] | undefined>(
url,
errorHandlingFetcher,
{
revalidateOnFocus: false, // Cache aggressively for performance
dedupingInterval: 60000, // Dedupe requests within 1 minute
}
);
const { data, error, mutate } = useSWR<
LLMProviderResponse<LLMProviderDescriptor> | undefined
>(url, errorHandlingFetcher, {
revalidateOnFocus: false, // Cache aggressively for performance
dedupingInterval: 60000, // Dedupe requests within 1 minute
});
return {
llmProviders: data,
llmProviders: data?.providers,
defaultText: data?.default_text,
defaultVision: data?.default_vision,
isLoading: !error && !data,
error,
refetch: mutate,

View File

@@ -1,14 +0,0 @@
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { fetchSS } from "../utilsSS";
export async function fetchLLMProvidersSS() {
// Test helper: allow Playwright runs to force an empty provider list so onboarding appears.
if (process.env.PLAYWRIGHT_FORCE_EMPTY_LLM_PROVIDERS === "true") {
return [];
}
const response = await fetchSS("/llm/provider");
if (response.ok) {
return (await response.json()) as LLMProviderDescriptor[];
}
return [];
}

View File

@@ -1,5 +1,6 @@
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
ModelConfiguration,
} from "@/app/admin/configuration/llm/interfaces";
@@ -8,14 +9,15 @@ import { LlmDescriptor } from "@/lib/hooks";
export function getFinalLLM(
llmProviders: LLMProviderDescriptor[],
persona: MinimalPersonaSnapshot | null,
currentLlm: LlmDescriptor | null
currentLlm: LlmDescriptor | null,
defaultLlmModel?: DefaultModel
): [string, string] {
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
(llmProvider) => llmProvider.id === defaultLlmModel?.provider_id
);
let provider = defaultProvider?.provider || "";
let model = defaultProvider?.default_model_name || "";
let model = defaultLlmModel?.model_name || "";
if (persona) {
// Map "provider override" to actual LLLMProvider

View File

@@ -1,6 +1,11 @@
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
import {
LLMProviderResponse,
VisionProvider,
} from "@/app/admin/configuration/llm/interfaces";
export async function fetchVisionProviders(): Promise<VisionProvider[]> {
export async function fetchVisionProviders(): Promise<
LLMProviderResponse<VisionProvider>
> {
const response = await fetch("/api/admin/llm/vision-providers", {
headers: {
"Content-Type": "application/json",

View File

@@ -93,15 +93,7 @@ export const testApiKeyHelper = async (
...(formValues?.custom_config ?? {}),
...(customConfigOverride ?? {}),
},
default_model_name: modelName ?? formValues?.default_model_name ?? "",
model_configurations: [
...(formValues.model_configurations || []).map(
(model: ModelConfiguration) => ({
name: model.name,
is_visible: true,
})
),
],
model: modelName ?? formValues?.default_model_name ?? "",
};
return await submitLlmTestRequest(

View File

@@ -6,7 +6,10 @@ import {
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
import {
LLM_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/app/admin/configuration/llm/constants";
import { OnboardingActions, OnboardingState } from "../types";
import { APIFormFieldState } from "@/refresh-components/form/types";
import {
@@ -225,10 +228,19 @@ export function OnboardingFormWrapper<T extends Record<string, any>>({
try {
const newLlmProvider = await response.json();
if (newLlmProvider?.id != null) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{ method: "POST" }
);
const defaultLLMModel =
payload.default_model_name ??
newLlmProvider.model_configurations[0].name;
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: defaultLLMModel,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
console.error("Failed to set provider as default", err?.detail);

View File

@@ -401,8 +401,14 @@ test("test, create, and set as default", async () => {
expect.objectContaining({ method: "PUT" })
);
expect(fetchSpy).toHaveBeenCalledWith(
"/api/llm/provider/5/default",
expect.objectContaining({ method: "POST" })
"/api/llm/provider/default",
expect.objectContaining({
method: "POST",
body: JSON.stringify({
provider_id: 5,
model_name: "gpt-4"
})
})
);
});
});

View File

@@ -67,7 +67,7 @@ test.describe("First user onboarding flow", () => {
}
);
await page.route("**/api/admin/llm/provider/*/default", async (route) => {
await page.route("**/api/admin/llm/default", async (route) => {
if (route.request().method() === "POST") {
await route.fulfill({
status: 200,