Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: v2.12.0-cl...craft_chan (60 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 908d360011 |  |
|  | 30578bdf9a |  |
|  | aebde89432 |  |
|  | 4a4b4bb378 |  |
|  | a8d231976a |  |
|  | 9c8ae5bb4b |  |
|  | 0fc1fa3d36 |  |
|  | 94633698c3 |  |
|  | 6ae15589cd |  |
|  | c24a8bb228 |  |
|  | 01945abd86 |  |
|  | 658632195f |  |
|  | ec6fd01ba4 |  |
|  | 148e6fb97d |  |
|  | 6598c1a48d |  |
|  | 497ce43bd8 |  |
|  | 8634cb0446 |  |
|  | 8d56fd3dc6 |  |
|  | a7579a99d0 |  |
|  | 3533c10da4 |  |
|  | 7b0414bf0d |  |
|  | b500ea537a |  |
|  | abd6d55add |  |
|  | f15b6b8034 |  |
|  | fb40485f25 |  |
|  | 22e85f1f28 |  |
|  | 2ef7c3e6f3 |  |
|  | 92a471ed2b |  |
|  | d1b7e529a4 |  |
|  | 95c3579264 |  |
|  | 8802e5cad3 |  |
|  | a41b4bbc82 |  |
|  | c026c077b5 |  |
|  | 3eee539a86 |  |
|  | 143e7a0d72 |  |
|  | 4572358038 |  |
|  | 1753f94c11 |  |
|  | 120ddf2ef6 |  |
|  | 2cce5bc58f |  |
|  | 383a6001d2 |  |
|  | 3a6f45bfca |  |
|  | e06b5ef202 |  |
|  | c13ce816fa |  |
|  | 39f3e872ec |  |
|  | b033c00217 |  |
|  | 6d47c5f21a |  |
|  | 0645540e24 |  |
|  | a2c0fc4df0 |  |
|  | 7dccc88b35 |  |
|  | ac617a51ce |  |
|  | 339a111a8f |  |
|  | 09b7e6fc9b |  |
|  | 135238014f |  |
|  | 303e37bf53 |  |
|  | 6a888e9900 |  |
|  | e90a7767c6 |  |
|  | 1ded3af63c |  |
|  | c53546c000 |  |
|  | 9afa12edda |  |
|  | 32046de962 |  |
@@ -0,0 +1,58 @@
"""LLMProvider deprecated fields are nullable

Revision ID: 001984c88745
Revises: 01f8e6d95a33
Create Date: 2026-02-01 22:24:34.171100

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "001984c88745"
down_revision = "01f8e6d95a33"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Make default_model_name nullable (was NOT NULL)
    op.alter_column(
        "llm_provider",
        "default_model_name",
        existing_type=sa.String(),
        nullable=True,
    )

    # Remove server_default from is_default_vision_provider (was server_default=false())
    op.alter_column(
        "llm_provider",
        "is_default_vision_provider",
        existing_type=sa.Boolean(),
        server_default=None,
    )

    # is_default_provider and default_vision_model are already nullable with no server_default


def downgrade() -> None:
    # Restore default_model_name to NOT NULL (set empty string for any NULLs first)
    op.execute(
        "UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
    )
    op.alter_column(
        "llm_provider",
        "default_model_name",
        existing_type=sa.String(),
        nullable=False,
    )

    # Restore server_default for is_default_vision_provider
    op.alter_column(
        "llm_provider",
        "is_default_vision_provider",
        existing_type=sa.Boolean(),
        server_default=sa.false(),
    )
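Not part of the diff, but a quick way to sanity-check `upgrade()` with SQLAlchemy's inspector; the connection string is a placeholder:

```python
from sqlalchemy import create_engine, inspect

# Placeholder DSN; point this at the database the migration ran against.
engine = create_engine("postgresql://localhost/onyx")
columns = {c["name"]: c for c in inspect(engine).get_columns("llm_provider")}

# After upgrade(): NULLs are accepted, and the server-side default is gone.
assert columns["default_model_name"]["nullable"] is True
assert columns["is_default_vision_provider"]["default"] is None
```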
@@ -0,0 +1,112 @@
"""Populate flow mapping data

Revision ID: 01f8e6d95a33
Revises: f220515df7b4
Create Date: 2026-01-31 17:37:10.485558

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "01f8e6d95a33"
down_revision = "f220515df7b4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add each model config to the conversation flow, setting the global default if it exists
    # Exclude models that are part of ImageGenerationConfig
    op.execute(
        """
        INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id)
        SELECT
            'chat' AS llm_model_flow_type,
            COALESCE(
                (lp.is_default_provider IS TRUE AND lp.default_model_name = mc.name),
                FALSE
            ) AS is_default,
            mc.id AS model_configuration_id
        FROM model_configuration mc
        LEFT JOIN llm_provider lp
            ON lp.id = mc.llm_provider_id
        WHERE NOT EXISTS (
            SELECT 1 FROM image_generation_config igc
            WHERE igc.model_configuration_id = mc.id
        );
        """
    )

    # Add models with supports_image_input to the vision flow
    op.execute(
        """
        INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id)
        SELECT
            'vision' AS llm_model_flow_type,
            COALESCE(
                (lp.is_default_vision_provider IS TRUE AND lp.default_vision_model = mc.name),
                FALSE
            ) AS is_default,
            mc.id AS model_configuration_id
        FROM model_configuration mc
        LEFT JOIN llm_provider lp
            ON lp.id = mc.llm_provider_id
        WHERE mc.supports_image_input IS TRUE;
        """
    )


def downgrade() -> None:
    # Populate vision defaults from model_flow
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET
            is_default_vision_provider = TRUE,
            default_vision_model = mc.name
        FROM llm_model_flow mf
        JOIN model_configuration mc ON mc.id = mf.model_configuration_id
        WHERE mf.llm_model_flow_type = 'vision'
            AND mf.is_default = TRUE
            AND mc.llm_provider_id = lp.id;
        """
    )

    # Populate conversation defaults from model_flow
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET
            is_default_provider = TRUE,
            default_model_name = mc.name
        FROM llm_model_flow mf
        JOIN model_configuration mc ON mc.id = mf.model_configuration_id
        WHERE mf.llm_model_flow_type = 'chat'
            AND mf.is_default = TRUE
            AND mc.llm_provider_id = lp.id;
        """
    )

    # For providers that have conversation flow mappings but aren't the default,
    # we still need a default_model_name (it was NOT NULL originally).
    # Pick the first visible model or any model for that provider.
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET default_model_name = (
            SELECT mc.name
            FROM model_configuration mc
            JOIN llm_model_flow mf ON mf.model_configuration_id = mc.id
            WHERE mc.llm_provider_id = lp.id
                AND mf.llm_model_flow_type = 'chat'
            ORDER BY mc.is_visible DESC, mc.id ASC
            LIMIT 1
        )
        WHERE lp.default_model_name IS NULL;
        """
    )

    # Delete all model_flow entries (reverse the inserts from upgrade)
    op.execute("DELETE FROM llm_model_flow;")
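Also not from the diff: a small sketch for inspecting what the backfill produced. After `upgrade()`, each flow type should carry at most one `is_default = TRUE` row (placeholder DSN again):

```python
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # placeholder DSN
with engine.connect() as conn:
    rows = conn.execute(
        text(
            "SELECT mf.llm_model_flow_type, mc.name "
            "FROM llm_model_flow mf "
            "JOIN model_configuration mc ON mc.id = mf.model_configuration_id "
            "WHERE mf.is_default IS TRUE"
        )
    )
    for flow_type, model_name in rows:
        print(f"default {flow_type} model: {model_name}")
```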
@@ -0,0 +1,57 @@
"""Add flow mapping table

Revision ID: f220515df7b4
Revises: cbc03e08d0f3
Create Date: 2026-01-30 12:21:24.955922

"""

from onyx.db.enums import LLMModelFlowType
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "f220515df7b4"
down_revision = "9d1543a37106"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "llm_model_flow",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "llm_model_flow_type",
            sa.Enum(LLMModelFlowType, name="llmmodelflowtype", native_enum=False),
            nullable=False,
        ),
        sa.Column(
            "is_default", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column("model_configuration_id", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE"
        ),
        sa.UniqueConstraint(
            "llm_model_flow_type",
            "model_configuration_id",
            name="uq_model_config_per_llm_model_flow_type",
        ),
    )

    # Partial unique index so that there is at most one default for each flow type
    op.create_index(
        "ix_one_default_per_llm_model_flow",
        "llm_model_flow",
        ["llm_model_flow_type"],
        unique=True,
        postgresql_where=sa.text("is_default IS TRUE"),
    )


def downgrade() -> None:
    # Drop the llm_model_flow table (index is dropped automatically with table)
    op.drop_table("llm_model_flow")
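A sketch (not part of the migration) of what the partial index enforces: non-default rows are unconstrained per flow type, while a second default per flow type is rejected. The referenced model_configuration ids are assumed to exist:

```python
from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError

engine = create_engine("postgresql://localhost/onyx")  # placeholder DSN

with engine.begin() as conn:
    # Any number of non-default 'chat' mappings is fine, and so is one default.
    conn.execute(text(
        "INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id) "
        "VALUES ('chat', FALSE, 1), ('chat', FALSE, 2), ('chat', TRUE, 3)"
    ))

try:
    with engine.begin() as conn:
        # A second 'chat' default violates ix_one_default_per_llm_model_flow.
        conn.execute(text(
            "INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id) "
            "VALUES ('chat', TRUE, 4)"
        ))
except IntegrityError:
    print("second default for 'chat' rejected by the partial unique index")
```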
@@ -123,9 +123,14 @@ def _seed_llms(
         upsert_llm_provider(llm_upsert_request, db_session)
         for llm_upsert_request in llm_upsert_requests
     ]
-    update_default_provider(
-        provider_id=seeded_providers[0].id, db_session=db_session
-    )
+
+    if len(seeded_providers[0].model_configurations) > 0:
+        default_model = seeded_providers[0].model_configurations[0].name
+        update_default_provider(
+            provider_id=seeded_providers[0].id,
+            model_name=default_model,
+            db_session=db_session,
+        )
 
 
 def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
@@ -300,12 +300,12 @@ def configure_default_api_keys(db_session: Session) -> None:
 
     has_set_default_provider = False
 
-    def _upsert(request: LLMProviderUpsertRequest) -> None:
+    def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
         nonlocal has_set_default_provider
         try:
             provider = upsert_llm_provider(request, db_session)
             if not has_set_default_provider:
-                update_default_provider(provider.id, db_session)
+                update_default_provider(provider.id, default_model, db_session)
                 has_set_default_provider = True
         except Exception as e:
             logger.error(f"Failed to configure {request.provider} provider: {e}")

@@ -323,14 +323,13 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="OpenAI",
             provider=OPENAI_PROVIDER_NAME,
             api_key=OPENAI_DEFAULT_API_KEY,
-            default_model_name=default_model_name,
             model_configurations=_build_model_configuration_upsert_requests(
                 OPENAI_PROVIDER_NAME, recommendations
             ),
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(openai_provider)
+        _upsert(openai_provider, default_model_name)
 
         # Create default image generation config using the OpenAI API key
         try:

@@ -359,14 +358,13 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="Anthropic",
             provider=ANTHROPIC_PROVIDER_NAME,
             api_key=ANTHROPIC_DEFAULT_API_KEY,
-            default_model_name=default_model_name,
             model_configurations=_build_model_configuration_upsert_requests(
                 ANTHROPIC_PROVIDER_NAME, recommendations
             ),
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(anthropic_provider)
+        _upsert(anthropic_provider, default_model_name)
     else:
         logger.info(
             "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"

@@ -391,14 +389,13 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="Google Vertex AI",
             provider=VERTEXAI_PROVIDER_NAME,
             custom_config=custom_config,
-            default_model_name=default_model_name,
             model_configurations=_build_model_configuration_upsert_requests(
                 VERTEXAI_PROVIDER_NAME, recommendations
             ),
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(vertexai_provider)
+        _upsert(vertexai_provider, default_model_name)
     else:
         logger.info(
             "VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"

@@ -430,12 +427,11 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="OpenRouter",
             provider=OPENROUTER_PROVIDER_NAME,
             api_key=OPENROUTER_DEFAULT_API_KEY,
-            default_model_name=default_model_name,
             model_configurations=model_configurations,
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(openrouter_provider)
+        _upsert(openrouter_provider, default_model_name)
     else:
         logger.info(
             "OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
@@ -285,3 +285,9 @@ class HierarchyNodeType(str, PyEnum):
 
     # Slack
     CHANNEL = "channel"
+
+
+class LLMModelFlowType(str, PyEnum):
+    CHAT = "chat"
+    VISION = "vision"
+    EMBEDDINGS = "embeddings"
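Because the enum subclasses `str`, members compare and serialize as plain strings, which is what lets the data migration above insert the literals 'chat' and 'vision'. A minimal standalone sketch:

```python
from enum import Enum

class LLMModelFlowType(str, Enum):  # mirrors the onyx.db.enums definition
    CHAT = "chat"
    VISION = "vision"
    EMBEDDINGS = "embeddings"

assert LLMModelFlowType.CHAT == "chat"  # str comparison works directly
assert LLMModelFlowType("vision") is LLMModelFlowType.VISION  # round-trips
```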
@@ -1,12 +1,15 @@
 from sqlalchemy import delete
 from sqlalchemy import select
+from sqlalchemy import update
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 
+from onyx.db.enums import LLMModelFlowType
 from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
 from onyx.db.models import DocumentSet
 from onyx.db.models import ImageGenerationConfig
+from onyx.db.models import LLMModelFlow
 from onyx.db.models import LLMProvider as LLMProviderModel
 from onyx.db.models import LLMProvider__Persona
 from onyx.db.models import LLMProvider__UserGroup
@@ -233,9 +236,6 @@ def upsert_llm_provider(
     existing_llm_provider.api_base = llm_provider_upsert_request.api_base
     existing_llm_provider.api_version = llm_provider_upsert_request.api_version
     existing_llm_provider.custom_config = custom_config
-    existing_llm_provider.default_model_name = (
-        llm_provider_upsert_request.default_model_name
-    )
     existing_llm_provider.is_public = llm_provider_upsert_request.is_public
     existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
     existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -244,9 +244,29 @@ def upsert_llm_provider(
     # If its not already in the db, we need to generate an ID by flushing
     db_session.flush()
 
-    # Delete existing model configurations
+    models_to_exist = {
+        mc.name for mc in llm_provider_upsert_request.model_configurations
+    }
+
+    removed_model_configuration_ids = {
+        mc.id
+        for mc in existing_llm_provider.model_configurations
+        if mc.name not in models_to_exist
+    }
+
+    updated_model_configuration_names = {
+        mc.name: mc.id
+        for mc in existing_llm_provider.model_configurations
+        if mc.name in models_to_exist
+    }
+
+    new_model_names = models_to_exist - {
+        mc.name for mc in existing_llm_provider.model_configurations
+    }
+
+    # Delete removed models
     db_session.query(ModelConfiguration).filter(
-        ModelConfiguration.llm_provider_id == existing_llm_provider.id
+        ModelConfiguration.id.in_(removed_model_configuration_ids)
     ).delete(synchronize_session="fetch")
 
     db_session.flush()
@@ -263,18 +283,31 @@ def upsert_llm_provider(
                 model_provider=llm_provider_upsert_request.provider,
             )
 
-            db_session.execute(
-                insert(ModelConfiguration)
-                .values(
-                    llm_provider_id=existing_llm_provider.id,
-                    name=model_configuration.name,
-                    is_visible=model_configuration.is_visible,
-                    max_input_tokens=max_input_tokens,
-                    display_name=model_configuration.display_name,
-                )
-                .on_conflict_do_nothing()
-            )
+            supported_flows = [LLMModelFlowType.CHAT]
+            if model_configuration.supports_image_input:
+                supported_flows.append(LLMModelFlowType.VISION)
+
+            if model_configuration.name in new_model_names:
+                insert_new_model_configuration__no_commit(
+                    db_session=db_session,
+                    llm_provider_id=existing_llm_provider.id,
+                    model_name=model_configuration.name,
+                    supported_flows=supported_flows,
+                    is_visible=model_configuration.is_visible,
+                    max_input_tokens=max_input_tokens,
+                    display_name=model_configuration.display_name,
+                )
+            elif model_configuration.name in updated_model_configuration_names:
+                update_model_configuration__no_commit(
+                    db_session=db_session,
+                    model_configuration_id=updated_model_configuration_names[
+                        model_configuration.name
+                    ],
+                    supported_flows=supported_flows,
+                    is_visible=model_configuration.is_visible,
+                    max_input_tokens=max_input_tokens,
+                    supports_image_input=model_configuration.supports_image_input,
+                    display_name=model_configuration.display_name,
+                )
 
     # Make sure the relationship table stays up to date
     update_group_llm_provider_relationships__no_commit(
@@ -331,17 +364,18 @@ def sync_model_configurations(
         model_name = model["name"]
         if model_name not in existing_names:
             # Insert new model with is_visible=False (user must explicitly enable)
-            db_session.execute(
-                insert(ModelConfiguration)
-                .values(
-                    llm_provider_id=provider.id,
-                    name=model_name,
-                    is_visible=False,
-                    max_input_tokens=model.get("max_input_tokens"),
-                    supports_image_input=model.get("supports_image_input", False),
-                    display_name=model.get("display_name"),
-                )
-                .on_conflict_do_nothing()
-            )
+            supported_flows = [LLMModelFlowType.CHAT]
+            if model.get("supports_image_input", False):
+                supported_flows.append(LLMModelFlowType.VISION)
+
+            insert_new_model_configuration__no_commit(
+                db_session=db_session,
+                llm_provider_id=provider.id,
+                model_name=model_name,
+                supported_flows=supported_flows,
+                is_visible=False,
+                max_input_tokens=model.get("max_input_tokens"),
+                display_name=model.get("display_name"),
+            )
             new_count += 1
@@ -371,8 +405,26 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
     )
 
 
+def fetch_existing_models(
+    db_session: Session,
+    flow_types: list[LLMModelFlowType],
+) -> list[ModelConfiguration]:
+    models = (
+        select(ModelConfiguration)
+        .join(LLMModelFlow)
+        .where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
+        .options(
+            selectinload(ModelConfiguration.llm_provider),
+            selectinload(ModelConfiguration.llm_model_flows),
+        )
+    )
+
+    return list(db_session.scalars(models).all())
+
+
 def fetch_existing_llm_providers(
     db_session: Session,
+    flow_types: list[LLMModelFlowType],
     only_public: bool = False,
     exclude_image_generation_providers: bool = True,
 ) -> list[LLMProviderModel]:
@@ -380,23 +432,37 @@ def fetch_existing_llm_providers(
 
     Args:
         db_session: Database session
+        flow_types: List of flow types to filter by
        only_public: If True, only return public providers
        exclude_image_generation_providers: If True, exclude providers that are
            used for image generation configs
    """
-    stmt = select(LLMProviderModel).options(
+    providers_with_flows = (
+        select(ModelConfiguration.llm_provider_id)
+        .join(LLMModelFlow)
+        .where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
+        .distinct()
+    )
+
+    if exclude_image_generation_providers:
+        stmt = select(LLMProviderModel).where(
+            LLMProviderModel.id.in_(providers_with_flows)
+        )
+    else:
+        image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
+            ImageGenerationConfig
+        )
+        stmt = select(LLMProviderModel).where(
+            LLMProviderModel.id.in_(providers_with_flows)
+            | LLMProviderModel.id.in_(image_gen_provider_ids)
+        )
+
+    stmt = stmt.options(
         selectinload(LLMProviderModel.model_configurations),
         selectinload(LLMProviderModel.groups),
         selectinload(LLMProviderModel.personas),
     )
 
-    if exclude_image_generation_providers:
-        # Get LLM provider IDs used by ImageGenerationConfig
-        image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
-            ImageGenerationConfig
-        )
-        stmt = stmt.where(LLMProviderModel.id.not_in(image_gen_provider_ids))
-
    providers = list(db_session.scalars(stmt).all())
    if only_public:
        return [provider for provider in providers if provider.is_public]
@@ -429,26 +495,29 @@ def fetch_embedding_provider(
     )
 
 
-def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
-    provider_model = db_session.scalar(
-        select(LLMProviderModel)
-        .where(LLMProviderModel.is_default_provider == True)  # noqa: E712
-        .options(selectinload(LLMProviderModel.model_configurations))
-    )
-    if not provider_model:
-        return None
-    return LLMProviderView.from_model(provider_model)
+def fetch_default_llm_model(db_session: Session) -> ModelConfiguration | None:
+    return fetch_default_model(db_session, LLMModelFlowType.CHAT)
 
 
-def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
-    provider_model = db_session.scalar(
-        select(LLMProviderModel)
-        .where(LLMProviderModel.is_default_vision_provider == True)  # noqa: E712
-        .options(selectinload(LLMProviderModel.model_configurations))
-    )
-    if not provider_model:
-        return None
-    return LLMProviderView.from_model(provider_model)
+def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None:
+    return fetch_default_model(db_session, LLMModelFlowType.VISION)
+
+
+def fetch_default_model(
+    db_session: Session,
+    flow_type: LLMModelFlowType,
+) -> ModelConfiguration | None:
+    model_config = db_session.scalar(
+        select(ModelConfiguration)
+        .join(LLMModelFlow)
+        .where(
+            ModelConfiguration.is_visible == True,  # noqa: E712
+            LLMModelFlow.llm_model_flow_type == flow_type,
+            LLMModelFlow.is_default == True,  # noqa: E712
+        )
+    )
+
+    return model_config
 
 
 def fetch_llm_provider_view(
@@ -526,61 +595,40 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
     db_session.flush()
 
 
-def update_default_provider(provider_id: int, db_session: Session) -> None:
-    new_default = db_session.scalar(
-        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
-    )
-    if not new_default:
-        raise ValueError(f"LLM Provider with id {provider_id} does not exist")
-
-    existing_default = db_session.scalar(
-        select(LLMProviderModel).where(
-            LLMProviderModel.is_default_provider == True  # noqa: E712
-        )
-    )
-    if existing_default:
-        existing_default.is_default_provider = None
-        # required to ensure that the below does not cause a unique constraint violation
-        db_session.flush()
-
-    new_default.is_default_provider = True
-    db_session.commit()
+def update_default_provider(
+    provider_id: int, model_name: str, db_session: Session
+) -> None:
+    _update_default_model(
+        db_session,
+        provider_id,
+        model_name,
+        LLMModelFlowType.CHAT,
+    )
 
 
 def update_default_vision_provider(
-    provider_id: int, vision_model: str | None, db_session: Session
+    provider_id: int, vision_model: str, db_session: Session
 ) -> None:
-    new_default = db_session.scalar(
-        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
-    )
-    if not new_default:
-        raise ValueError(f"LLM Provider with id {provider_id} does not exist")
-
-    # Validate that the specified vision model supports image input
-    model_to_validate = vision_model or new_default.default_model_name
-    if model_to_validate:
-        if not model_supports_image_input(model_to_validate, new_default.provider):
-            raise ValueError(
-                f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
-            )
-    else:
-        raise ValueError(
-            f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
-        )
-
-    existing_default = db_session.scalar(
+    provider = db_session.scalar(
         select(LLMProviderModel).where(
-            LLMProviderModel.is_default_vision_provider == True  # noqa: E712
+            LLMProviderModel.id == provider_id,
         )
     )
-    if existing_default:
-        existing_default.is_default_vision_provider = None
-        # required to ensure that the below does not cause a unique constraint violation
-        db_session.flush()
-
-    new_default.is_default_vision_provider = True
-    new_default.default_vision_model = vision_model
-    db_session.commit()
+    if provider is None:
+        raise ValueError(f"LLM Provider with id={provider_id} does not exist")
+
+    if not model_supports_image_input(vision_model, provider.provider):
+        raise ValueError(
+            f"Model '{vision_model}' for provider '{provider.provider} does not support image input"
+        )
+
+    _update_default_model(
+        db_session=db_session,
+        provider_id=provider_id,
+        model=vision_model,
+        flow_type=LLMModelFlowType.VISION,
+    )
 
 
 def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
@@ -670,11 +718,162 @@ def sync_auto_mode_models(
             db_session.add(new_model)
             changes += 1
 
-    # In Auto mode, default model is always set from GitHub config
-    default_model = llm_recommendations.get_default_model(provider.provider)
-    if default_model and provider.default_model_name != default_model.name:
-        provider.default_model_name = default_model.name
-        changes += 1
-
     db_session.commit()
     return changes
+
+
+def create_new_flow_mapping__no_commit(
+    db_session: Session,
+    model_configuration_id: int,
+    flow_type: LLMModelFlowType,
+) -> LLMModelFlow:
+    result = db_session.execute(
+        insert(LLMModelFlow)
+        .values(
+            model_configuration_id=model_configuration_id,
+            llm_model_flow_type=flow_type,
+            is_default=False,
+        )
+        .on_conflict_do_nothing()
+        .returning(LLMModelFlow)
+    )
+
+    flow = result.scalar()
+    if not flow:
+        raise ValueError(
+            f"Failed to create new flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}"
+        )
+
+    return flow
+
+
+def insert_new_model_configuration__no_commit(
+    db_session: Session,
+    llm_provider_id: int,
+    model_name: str,
+    supported_flows: list[LLMModelFlowType],
+    is_visible: bool,
+    max_input_tokens: int | None,
+    display_name: str | None,
+) -> int | None:
+    result = db_session.execute(
+        insert(ModelConfiguration)
+        .values(
+            llm_provider_id=llm_provider_id,
+            name=model_name,
+            is_visible=is_visible,
+            max_input_tokens=max_input_tokens,
+            display_name=display_name,
+        )
+        .on_conflict_do_nothing()
+        .returning(ModelConfiguration.id)
+    )
+
+    model_config_id = result.scalar()
+
+    if not model_config_id:
+        return None
+
+    for flow_type in supported_flows:
+        create_new_flow_mapping__no_commit(
+            db_session=db_session,
+            model_configuration_id=model_config_id,
+            flow_type=flow_type,
+        )
+
+    return model_config_id
+
+
+def update_model_configuration__no_commit(
+    db_session: Session,
+    model_configuration_id: int,
+    supported_flows: list[LLMModelFlowType],
+    is_visible: bool,
+    max_input_tokens: int | None,
+    display_name: str | None,
+) -> None:
+    result = db_session.execute(
+        update(ModelConfiguration)
+        .values(
+            is_visible=is_visible,
+            max_input_tokens=max_input_tokens,
+            display_name=display_name,
+        )
+        .where(ModelConfiguration.id == model_configuration_id)
+        .returning(ModelConfiguration)
+    )
+
+    model_configuration = result.scalar()
+    if not model_configuration:
+        raise ValueError(
+            f"Failed to update model configuration with id={model_configuration_id}"
+        )
+
+    new_flows = {
+        flow_type
+        for flow_type in supported_flows
+        if flow_type not in model_configuration.llm_model_flow_types
+    }
+    removed_flows = {
+        flow_type
+        for flow_type in model_configuration.llm_model_flow_types
+        if flow_type not in supported_flows
+    }
+
+    for flow_type in new_flows:
+        create_new_flow_mapping__no_commit(
+            db_session=db_session,
+            model_configuration_id=model_configuration_id,
+            flow_type=flow_type,
+        )
+
+    for flow_type in removed_flows:
+        db_session.execute(
+            delete(LLMModelFlow).where(
+                LLMModelFlow.model_configuration_id == model_configuration_id,
+                LLMModelFlow.llm_model_flow_type == flow_type,
+            )
+        )
+
+    db_session.flush()
+
+
+def _update_default_model(
+    db_session: Session,
+    provider_id: int,
+    model: str,
+    flow_type: LLMModelFlowType,
+) -> None:
+    result = db_session.execute(
+        select(ModelConfiguration, LLMModelFlow)
+        .join(
+            LLMModelFlow, LLMModelFlow.model_configuration_id == ModelConfiguration.id
+        )
+        .where(
+            ModelConfiguration.llm_provider_id == provider_id,
+            ModelConfiguration.name == model,
+            LLMModelFlow.llm_model_flow_type == flow_type,
+        )
+    ).first()
+
+    if not result:
+        raise ValueError(
+            f"Model '{model}' is not a valid model for provider_id={provider_id}"
+        )
+
+    model_config, new_default = result
+
+    # Clear existing default and set in an atomic operation
+    db_session.execute(
+        update(LLMModelFlow)
+        .where(
+            LLMModelFlow.llm_model_flow_type == flow_type,
+            LLMModelFlow.is_default == True,  # noqa: E712
+        )
+        .values(is_default=False)
+    )
+
+    new_default.is_default = True
+    model_config.is_visible = True
+
+    db_session.commit()
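A usage sketch for the reworked default-setting path; it assumes an open `db_session` and a provider with id 1 that actually has a "gpt-4o" configuration (both illustrative). Note that `_update_default_model` also flips the chosen configuration to visible:

```python
from onyx.db.llm import fetch_default_llm_model, update_default_provider

# Illustrative id/name; raises ValueError if the model isn't mapped to the
# CHAT flow for this provider.
update_default_provider(provider_id=1, model_name="gpt-4o", db_session=db_session)

default = fetch_default_llm_model(db_session)
assert default is not None and default.name == "gpt-4o"
```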
@@ -7,6 +7,7 @@ from uuid import uuid4
 
 from pydantic import BaseModel
+from sqlalchemy.orm import validates
 
 from typing_extensions import TypedDict  # noreorder
 from uuid import UUID
 from pydantic import ValidationError

@@ -72,6 +73,7 @@ from onyx.db.enums import (
     MCPAuthenticationPerformer,
     MCPTransport,
     MCPServerStatus,
+    LLMModelFlowType,
     ThemePreference,
     SwitchoverType,
 )
@@ -2704,14 +2706,9 @@ class LLMProvider(Base):
     custom_config: Mapped[dict[str, str] | None] = mapped_column(
         postgresql.JSONB(), nullable=True
     )
-    default_model_name: Mapped[str] = mapped_column(String)
 
     deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
 
-    # should only be set for a single provider
-    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
-    is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
-    default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
     # EE only
     is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
     # Auto mode: models, visibility, and defaults are managed by GitHub config
@@ -2761,8 +2758,6 @@ class ModelConfiguration(Base):
     # - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
     max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
 
-    supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
-
     # Human-readable display name for the model.
     # For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
     # For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
@@ -2773,6 +2768,51 @@ class ModelConfiguration(Base):
         back_populates="model_configurations",
     )
 
+    llm_model_flows: Mapped[list["LLMModelFlow"]] = relationship(
+        "LLMModelFlow",
+        back_populates="model_configuration",
+        cascade="all, delete-orphan",
+        passive_deletes=True,
+    )
+
+    @property
+    def llm_model_flow_types(self) -> list[LLMModelFlowType]:
+        return [flow.llm_model_flow_type for flow in self.llm_model_flows]
+
+
+class LLMModelFlow(Base):
+    __tablename__ = "llm_model_flow"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+
+    llm_model_flow_type: Mapped[LLMModelFlowType] = mapped_column(
+        Enum(LLMModelFlowType, native_enum=False), nullable=False
+    )
+    model_configuration_id: Mapped[int] = mapped_column(
+        ForeignKey("model_configuration.id", ondelete="CASCADE"),
+        nullable=False,
+    )
+    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+
+    model_configuration: Mapped["ModelConfiguration"] = relationship(
+        "ModelConfiguration",
+        back_populates="llm_model_flows",
+    )
+
+    __table_args__ = (
+        UniqueConstraint(
+            "llm_model_flow_type",
+            "model_configuration_id",
+            name="uq_model_config_per_llm_model_flow_type",
+        ),
+        Index(
+            "ix_one_default_per_llm_model_flow",
+            "llm_model_flow_type",
+            unique=True,
+            postgresql_where=(is_default == True),  # noqa: E712
+        ),
+    )
+
 
 class ImageGenerationConfig(Base):
     __tablename__ = "image_generation_config"
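For orientation, a sketch of how the new relationship and helper property read from application code (an open `db_session` with existing rows is assumed):

```python
from onyx.db.enums import LLMModelFlowType
from onyx.db.models import ModelConfiguration

mc = db_session.query(ModelConfiguration).first()  # assumes at least one row
if mc is not None:
    # llm_model_flow_types flattens the relationship into plain enum members.
    print(mc.name, mc.llm_model_flow_types)
    if LLMModelFlowType.VISION in mc.llm_model_flow_types:
        print(f"{mc.name} is registered for the vision flow")
```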
@@ -1,22 +1,20 @@
 from collections.abc import Callable
 
 from sqlalchemy.orm import Session
 
 from onyx.chat.models import PersonaOverrideConfig
 from onyx.configs.model_configs import GEN_AI_TEMPERATURE
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
+from onyx.db.enums import LLMModelFlowType
 from onyx.db.llm import can_user_access_llm_provider
-from onyx.db.llm import fetch_default_provider
-from onyx.db.llm import fetch_default_vision_provider
+from onyx.db.llm import fetch_default_llm_model
+from onyx.db.llm import fetch_default_vision_model
 from onyx.db.llm import fetch_existing_llm_provider
 from onyx.db.llm import fetch_existing_llm_providers
+from onyx.db.llm import fetch_existing_models
 from onyx.db.llm import fetch_llm_provider_view
 from onyx.db.llm import fetch_user_group_ids
 from onyx.db.models import Persona
 from onyx.db.models import User
 from onyx.llm.constants import LlmProviderNames
 from onyx.llm.interfaces import LLM
 from onyx.llm.interfaces import LLMConfig
 from onyx.llm.multi_llm import LitellmLLM
 from onyx.llm.override_models import LLMOverride
 from onyx.llm.utils import get_max_input_tokens_from_llm_provider
@@ -52,54 +50,6 @@ def _build_provider_extra_headers(
     return {}
 
 
-def get_llm_config_for_persona(
-    persona: Persona,
-    db_session: Session,
-    llm_override: LLMOverride | None = None,
-) -> LLMConfig:
-    """Get LLM config from persona without access checks.
-
-    This function assumes access to the persona has already been verified.
-    Use this when you need the LLM config but don't need to create the full LLM object.
-    """
-    provider_name_override = llm_override.model_provider if llm_override else None
-    model_version_override = llm_override.model_version if llm_override else None
-    temperature_override = llm_override.temperature if llm_override else None
-
-    provider_name = provider_name_override or persona.llm_model_provider_override
-    if not provider_name:
-        llm_provider = fetch_default_provider(db_session)
-        if not llm_provider:
-            raise ValueError("No default LLM provider found")
-        model_name: str | None = llm_provider.default_model_name
-    else:
-        llm_provider = fetch_llm_provider_view(db_session, provider_name)
-        if not llm_provider:
-            raise ValueError(f"No LLM provider found with name: {provider_name}")
-        model_name = model_version_override or persona.llm_model_version_override
-        if not model_name:
-            model_name = llm_provider.default_model_name
-
-    if not model_name:
-        raise ValueError("No model name found")
-
-    max_input_tokens = get_max_input_tokens_from_llm_provider(
-        llm_provider=llm_provider, model_name=model_name
-    )
-
-    return LLMConfig(
-        model_provider=llm_provider.provider,
-        model_name=model_name,
-        temperature=temperature_override or GEN_AI_TEMPERATURE,
-        api_key=llm_provider.api_key,
-        api_base=llm_provider.api_base,
-        api_version=llm_provider.api_version,
-        deployment_name=llm_provider.deployment_name,
-        custom_config=llm_provider.custom_config,
-        max_input_tokens=max_input_tokens,
-    )
-
-
 def get_llm_for_persona(
     persona: Persona | PersonaOverrideConfig | None,
     user: User,
@@ -203,46 +153,41 @@ def get_default_llm_with_vision(
 
     with get_session_with_current_tenant() as db_session:
         # Try the default vision provider first
-        default_provider = fetch_default_vision_provider(db_session)
-        if default_provider and default_provider.default_vision_model:
+        default_model = fetch_default_vision_model(db_session)
+        if default_model:
             if model_supports_image_input(
-                default_provider.default_vision_model, default_provider.provider
+                default_model.name, default_model.llm_provider.provider
             ):
                 return create_vision_llm(
-                    default_provider, default_provider.default_vision_model
+                    LLMProviderView.from_model(default_model.llm_provider),
+                    default_model.name,
                 )
 
         # Fall back to searching all providers
-        providers = fetch_existing_llm_providers(db_session)
+        models = fetch_existing_models(
+            db_session=db_session,
+            flow_types=[LLMModelFlowType.VISION, LLMModelFlowType.CHAT],
+        )
 
-        if not providers:
+        if not models:
             return None
 
-        # Check all providers for viable vision models
-        for provider in providers:
-            provider_view = LLMProviderView.from_model(provider)
-
-            # First priority: Check if provider has a default_vision_model
-            if provider.default_vision_model and model_supports_image_input(
-                provider.default_vision_model, provider.provider
-            ):
-                return create_vision_llm(provider_view, provider.default_vision_model)
-
-            # If no model-configurations are specified, try default model
-            if not provider.model_configurations:
-                # Try default_model_name
-                if provider.default_model_name and model_supports_image_input(
-                    provider.default_model_name, provider.provider
-                ):
-                    return create_vision_llm(provider_view, provider.default_model_name)
-
-            # Otherwise, if model-configurations are specified, check each model
-            else:
-                for model_configuration in provider.model_configurations:
-                    if model_supports_image_input(
-                        model_configuration.name, provider.provider
-                    ):
-                        return create_vision_llm(provider_view, model_configuration.name)
+        # Search for viable vision model followed by conversation models
+        # Sort models from VISION to CONVERSATION priority
+        sorted_models = sorted(
+            models,
+            key=lambda x: (
+                LLMModelFlowType.VISION in x.llm_model_flow_types,
+                LLMModelFlowType.CHAT in x.llm_model_flow_types,
+            ),
+            reverse=True,
+        )
+
+        for model in sorted_models:
+            if model_supports_image_input(model.name, model.llm_provider.provider):
+                return create_vision_llm(
+                    LLMProviderView.from_model(model.llm_provider),
+                    model.name,
+                )
 
         return None
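The fallback ordering above relies on tuple comparison: with `reverse=True`, a key of `(True, ...)` for a vision-capable model sorts ahead of `(False, True)` for a chat-only one. A standalone sketch with illustrative data:

```python
# Each entry: (model name, set of flow names it belongs to)
models = [
    ("gpt-4o-mini", {"chat"}),
    ("gpt-4o", {"chat", "vision"}),
    ("o3-mini", {"chat"}),
]

ordered = sorted(
    models,
    key=lambda m: ("vision" in m[1], "chat" in m[1]),
    reverse=True,
)

assert ordered[0][0] == "gpt-4o"  # the vision-capable model is tried first
```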
@@ -288,22 +233,18 @@ def get_default_llm(
     additional_headers: dict[str, str] | None = None,
 ) -> LLM:
     with get_session_with_current_tenant() as db_session:
-        llm_provider = fetch_default_provider(db_session)
+        model = fetch_default_llm_model(db_session)
 
-        if not llm_provider:
-            raise ValueError("No default LLM provider found")
+        if not model:
+            raise ValueError("No default LLM model found")
 
-        model_name = llm_provider.default_model_name
-        if not model_name:
-            raise ValueError("No default model name found")
-
-        return llm_from_provider(
-            model_name=model_name,
-            llm_provider=llm_provider,
-            timeout=timeout,
-            temperature=temperature,
-            additional_headers=additional_headers,
-        )
+        return llm_from_provider(
+            model_name=model.name,
+            llm_provider=LLMProviderView.from_model(model.llm_provider),
+            timeout=timeout,
+            temperature=temperature,
+            additional_headers=additional_headers,
+        )
 
 
 def get_llm(
@@ -17,6 +17,7 @@ from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
 from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
 from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
+from onyx.db.enums import LLMModelFlowType
 from onyx.db.models import LLMProvider
 from onyx.db.models import ModelConfiguration
 from onyx.llm.constants import LlmProviderNames
@@ -689,8 +690,11 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
                 LLMProvider.provider == model_provider,
             )
         )
-        if model_config and model_config.supports_image_input is not None:
-            return model_config.supports_image_input
+        if (
+            model_config
+            and LLMModelFlowType.VISION in model_config.llm_model_flow_types
+        ):
+            return True
     except Exception as e:
         logger.warning(
             f"Failed to query database for {model_provider} model {model_name} image support: {e}"
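Reduced to its essence, the database branch of `model_supports_image_input` now asks about flow membership instead of a per-model boolean column; a sketch of that predicate with the ORM lookup omitted:

```python
from onyx.db.enums import LLMModelFlowType

def db_says_vision(model_config) -> bool:
    # model_config is a ModelConfiguration row, or None when the model is
    # unknown; VISION flow membership replaces the dropped
    # supports_image_input column as the signal.
    return (
        model_config is not None
        and LLMModelFlowType.VISION in model_config.llm_model_flow_types
    )
```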
@@ -28,7 +28,7 @@ from sqlalchemy.orm import Session as DBSession
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.constants import MessageType
 from onyx.db.enums import SandboxStatus
-from onyx.db.llm import fetch_default_provider
+from onyx.db.llm import fetch_default_llm_model
 from onyx.db.models import BuildMessage
 from onyx.db.models import BuildSession
 from onyx.db.models import User
@@ -338,15 +338,15 @@ class SessionManager:
             )
 
         # Fallback to system default
-        default_provider = fetch_default_provider(self._db_session)
-        if not default_provider:
-            raise ValueError("No LLM provider configured")
+        default_model = fetch_default_llm_model(self._db_session)
+        if not default_model:
+            raise ValueError("No default LLM model found")
 
         return LLMProviderConfig(
-            provider=default_provider.provider,
-            model_name=default_provider.default_model_name,
-            api_key=default_provider.api_key,
-            api_base=default_provider.api_base,
+            provider=default_model.llm_provider.provider,
+            model_name=default_model.name,
+            api_key=default_model.llm_provider.api_key,
+            api_base=default_model.llm_provider.api_base,
         )
 
     # =========================================================================
@@ -93,7 +93,6 @@ def _build_llm_provider_request(
         api_key=source_provider.api_key,  # Only this from source
         api_base=api_base,  # From request
         api_version=api_version,  # From request
-        default_model_name=model_name,
         deployment_name=deployment_name,  # From request
         is_public=True,
         groups=[],

@@ -132,7 +131,6 @@ def _build_llm_provider_request(
         api_key=api_key,
         api_base=api_base,
         api_version=api_version,
-        default_model_name=model_name,
         deployment_name=deployment_name,
         is_public=True,
         groups=[],

@@ -164,7 +162,6 @@ def _create_image_gen_llm_provider__no_commit(
         api_key=provider_request.api_key,
         api_base=provider_request.api_base,
         api_version=provider_request.api_version,
-        default_model_name=provider_request.default_model_name,
         deployment_name=provider_request.deployment_name,
         is_public=provider_request.is_public,
         custom_config=provider_request.custom_config,
@@ -1,4 +1,5 @@
 import os
+from collections import defaultdict
 from datetime import datetime
 from datetime import timezone
 from typing import Any
@@ -19,9 +20,13 @@ from onyx.auth.schemas import UserRole
 from onyx.auth.users import current_admin_user
 from onyx.auth.users import current_chat_accessible_user
 from onyx.db.engine.sql_engine import get_session
+from onyx.db.enums import LLMModelFlowType
 from onyx.db.llm import can_user_access_llm_provider
+from onyx.db.llm import fetch_default_llm_model
+from onyx.db.llm import fetch_default_vision_model
 from onyx.db.llm import fetch_existing_llm_provider
 from onyx.db.llm import fetch_existing_llm_providers
+from onyx.db.llm import fetch_existing_models
 from onyx.db.llm import fetch_persona_with_groups
 from onyx.db.llm import fetch_user_group_ids
 from onyx.db.llm import remove_llm_provider
@@ -37,7 +42,6 @@ from onyx.llm.factory import get_llm
 from onyx.llm.factory import get_max_input_tokens_from_llm_provider
 from onyx.llm.utils import get_bedrock_token_limit
 from onyx.llm.utils import get_llm_contextual_cost
-from onyx.llm.utils import model_supports_image_input
 from onyx.llm.utils import test_llm
 from onyx.llm.well_known_providers.auto_update_service import (
     fetch_llm_recommendations_from_github,
@@ -50,11 +54,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
 )
 from onyx.server.manage.llm.models import BedrockFinalModelResponse
 from onyx.server.manage.llm.models import BedrockModelsRequest
+from onyx.server.manage.llm.models import DefaultModel
 from onyx.server.manage.llm.models import LLMCost
 from onyx.server.manage.llm.models import LLMProviderDescriptor
+from onyx.server.manage.llm.models import LLMProviderResponse
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
 from onyx.server.manage.llm.models import LLMProviderView
 from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
 from onyx.server.manage.llm.models import OllamaFinalModelResponse
 from onyx.server.manage.llm.models import OllamaModelDetails
 from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -211,7 +216,7 @@ def test_llm_configuration(
 
     llm = get_llm(
         provider=test_llm_request.provider,
-        model=test_llm_request.default_model_name,
+        model=test_llm_request.model,
         api_key=test_api_key,
         api_base=test_llm_request.api_base,
         api_version=test_llm_request.api_version,
@@ -246,13 +251,15 @@ def list_llm_providers(
     include_image_gen: bool = Query(False),
     _: User = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-) -> list[LLMProviderView]:
+) -> LLMProviderResponse[LLMProviderView]:
     start_time = datetime.now(timezone.utc)
     logger.debug("Starting to fetch LLM providers")
 
     llm_provider_list: list[LLMProviderView] = []
     for llm_provider_model in fetch_existing_llm_providers(
-        db_session, exclude_image_generation_providers=not include_image_gen
+        db_session=db_session,
+        flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
+        exclude_image_generation_providers=not include_image_gen,
     ):
         from_model_start = datetime.now(timezone.utc)
         full_llm_provider = LLMProviderView.from_model(llm_provider_model)
@@ -269,7 +276,25 @@ def list_llm_providers(
     duration = (end_time - start_time).total_seconds()
     logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
 
-    return llm_provider_list
+    default_model = None
+    if model_config := fetch_default_llm_model(db_session):
+        default_model = DefaultModel(
+            provider_id=model_config.llm_provider.id,
+            model_name=model_config.name,
+        )
+
+    default_vision_model = None
+    if model_config := fetch_default_vision_model(db_session):
+        default_vision_model = DefaultModel(
+            provider_id=model_config.llm_provider.id,
+            model_name=model_config.name,
+        )
+
+    return LLMProviderResponse[LLMProviderView].from_models(
+        providers=llm_provider_list,
+        default_text=default_model,
+        default_vision=default_vision_model,
+    )
 
 
 @admin_router.put("/provider")
@@ -328,20 +353,6 @@ def put_llm_provider(
             deduplicated_personas.append(persona_id)
         llm_provider_upsert_request.personas = deduplicated_personas
 
-    default_model_found = False
-
-    for model_configuration in llm_provider_upsert_request.model_configurations:
-        if model_configuration.name == llm_provider_upsert_request.default_model_name:
-            model_configuration.is_visible = True
-            default_model_found = True
-
-    if not default_model_found:
-        llm_provider_upsert_request.model_configurations.append(
-            ModelConfigurationUpsertRequest(
-                name=llm_provider_upsert_request.default_model_name, is_visible=True
-            )
-        )
-
     # the llm api key is sanitized when returned to clients, so the only time we
     # should get a real key is when it is explicitly changed
     if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -398,26 +409,29 @@ def delete_llm_provider(
         raise HTTPException(status_code=404, detail=str(e))
 
 
-@admin_router.post("/provider/{provider_id}/default")
+@admin_router.post("/default")
 def set_provider_as_default(
-    provider_id: int,
+    default_model_request: DefaultModel,
     _: User = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
 ) -> None:
-    update_default_provider(provider_id=provider_id, db_session=db_session)
+    update_default_provider(
+        provider_id=default_model_request.provider_id,
+        model_name=default_model_request.model_name,
+        db_session=db_session,
+    )
 
 
-@admin_router.post("/provider/{provider_id}/default-vision")
+@admin_router.post("/default-vision")
 def set_provider_as_default_vision(
-    provider_id: int,
-    vision_model: str | None = Query(
-        None, description="The default vision model to use"
-    ),
+    default_model_request: DefaultModel,
     _: User = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
 ) -> None:
     update_default_vision_provider(
-        provider_id=provider_id, vision_model=vision_model, db_session=db_session
+        provider_id=default_model_request.provider_id,
+        vision_model=default_model_request.model_name,
+        db_session=db_session,
     )
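A sketch of exercising the reshaped route: the provider id moves from the path into the `DefaultModel` JSON body. The base URL, route prefix, and auth header are assumptions, not taken from the diff:

```python
import requests

resp = requests.post(
    "http://localhost:8080/admin/llm/default",  # hypothetical prefix for admin_router
    json={"provider_id": 1, "model_name": "gpt-4o"},  # DefaultModel payload
    headers={"Authorization": "Bearer <admin-token>"},  # placeholder auth
)
resp.raise_for_status()
```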
@@ -443,43 +457,47 @@ def get_auto_config(
 def get_vision_capable_providers(
     _: User = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-) -> list[VisionProviderResponse]:
+) -> LLMProviderResponse[VisionProviderResponse]:
     """Return a list of LLM providers and their models that support image input"""
-    providers = fetch_existing_llm_providers(db_session)
-    vision_providers = []
+    vision_models = fetch_existing_models(
+        db_session=db_session, flow_types=[LLMModelFlowType.VISION]
+    )
+
+    # Group vision models by provider ID (using ID as key since it's hashable)
+    provider_models: dict[int, list[str]] = defaultdict(list)
+    providers_by_id: dict[int, LLMProviderView] = {}
 
     logger.info("Fetching vision-capable providers")
 
-    for provider in providers:
-        vision_models = []
-
-        # Check each model for vision capability
-        for model_configuration in provider.model_configurations:
-            if model_supports_image_input(model_configuration.name, provider.provider):
-                vision_models.append(model_configuration.name)
-                logger.debug(
-                    f"Vision model found: {provider.provider}/{model_configuration.name}"
-                )
-
-        # Only include providers with at least one vision-capable model
-        if vision_models:
-            provider_view = LLMProviderView.from_model(provider)
-            vision_providers.append(
-                VisionProviderResponse(
-                    **provider_view.model_dump(),
-                    vision_models=vision_models,
-                )
-            )
-            logger.info(
-                f"Vision provider: {provider.provider} with models: {vision_models}"
-            )
-
-    logger.info(f"Found {len(vision_providers)} vision-capable providers")
-    return vision_providers
+    for vision_model in vision_models:
+        provider_id = vision_model.llm_provider.id
+        provider_models[provider_id].append(vision_model.name)
+        # Only create the view once per provider
+        if provider_id not in providers_by_id:
+            provider_view = LLMProviderView.from_model(vision_model.llm_provider)
+            _mask_provider_credentials(provider_view)
+            providers_by_id[provider_id] = provider_view
+
+    # Build response list
+    vision_provider_response = [
+        VisionProviderResponse(
+            **providers_by_id[provider_id].model_dump(),
+            vision_models=model_names,
+        )
+        for provider_id, model_names in provider_models.items()
+    ]
+
+    logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
+
+    default_vision_model = None
+    if model_config := fetch_default_vision_model(db_session):
+        default_vision_model = DefaultModel(
+            provider_id=model_config.llm_provider.id,
+            model_name=model_config.name,
+        )
+
+    return LLMProviderResponse[VisionProviderResponse].from_models(
+        providers=vision_provider_response,
+        default_vision=default_vision_model,
+    )
 
 
 """Endpoints for all"""
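The grouping idiom in the rewritten endpoint is an ordinary `defaultdict(list)` keyed by provider id; in isolation (with illustrative data):

```python
from collections import defaultdict

# (provider_id, model_name) pairs as the VISION flow query might return them
rows = [(1, "gpt-4o"), (1, "gpt-4o-mini"), (2, "claude-3-5-sonnet")]

provider_models: dict[int, list[str]] = defaultdict(list)
for provider_id, model_name in rows:
    provider_models[provider_id].append(model_name)

assert provider_models[1] == ["gpt-4o", "gpt-4o-mini"]
assert provider_models[2] == ["claude-3-5-sonnet"]
```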
@@ -489,7 +507,7 @@ def get_vision_capable_providers(
 def list_llm_provider_basics(
     user: User = Depends(current_chat_accessible_user),
     db_session: Session = Depends(get_session),
-) -> list[LLMProviderDescriptor]:
+) -> LLMProviderResponse[LLMProviderDescriptor]:
     """Get LLM providers accessible to the current user.
 
     Returns:
@@ -502,7 +520,9 @@ def list_llm_provider_basics(
     start_time = datetime.now(timezone.utc)
     logger.debug("Starting to fetch user-accessible LLM providers")
 
-    all_providers = fetch_existing_llm_providers(db_session)
+    all_providers = fetch_existing_llm_providers(
+        db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
+    )
     user_group_ids = fetch_user_group_ids(db_session, user)
     is_admin = user.role == UserRole.ADMIN
@@ -526,7 +546,25 @@ def list_llm_provider_basics(
         f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
     )
 
-    return accessible_providers
+    default_model = None
+    if model_config := fetch_default_llm_model(db_session):
+        default_model = DefaultModel(
+            provider_id=model_config.llm_provider.id,
+            model_name=model_config.name,
+        )
+
+    default_vision_model = None
+    if model_config := fetch_default_vision_model(db_session):
+        default_vision_model = DefaultModel(
+            provider_id=model_config.llm_provider.id,
+            model_name=model_config.name,
+        )
+
+    return LLMProviderResponse[LLMProviderDescriptor].from_models(
+        providers=accessible_providers,
+        default_text=default_model,
+        default_vision=default_vision_model,
+    )
 
 
 def get_valid_model_names_for_persona(
@@ -545,7 +583,9 @@ def get_valid_model_names_for_persona(
         return []
 
     is_admin = user.role == UserRole.ADMIN
-    all_providers = fetch_existing_llm_providers(db_session)
+    all_providers = fetch_existing_llm_providers(
+        db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
+    )
     user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
 
     valid_models = []
@@ -567,7 +607,7 @@ def list_llm_providers_for_persona(
|
||||
persona_id: int,
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
) -> LLMProviderResponse[LLMProviderDescriptor]:
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
@@ -592,7 +632,9 @@ def list_llm_providers_for_persona(
|
||||
)
|
||||
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
all_providers = fetch_existing_llm_providers(
|
||||
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
|
||||
)
|
||||
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
|
||||
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
@@ -612,7 +654,11 @@ def list_llm_providers_for_persona(
|
||||
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return llm_provider_list
|
||||
return LLMProviderResponse[LLMProviderDescriptor].from_models(
|
||||
providers=llm_provider_list,
|
||||
default_text=None,
|
||||
default_vision=None,
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
@@ -630,7 +676,7 @@ def get_provider_contextual_cost(
|
||||
- the chunk_context
|
||||
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
|
||||
"""
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
providers = fetch_existing_llm_providers(db_session, [LLMModelFlowType.CHAT])
|
||||
costs = []
|
||||
for provider in providers:
|
||||
for model_configuration in provider.model_configurations:
|
||||
|
||||
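All of the listing endpoints above now pass an explicit flow-type filter. A minimal sketch of the new call shape, assuming a live db_session; the exact filtering semantics are only implied by these hunks, since fetch_existing_llm_providers itself is not shown in this diff:

from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_existing_llm_providers

# Chat-only fetch, as used by the contextual-cost endpoint above.
chat_providers = fetch_existing_llm_providers(db_session, [LLMModelFlowType.CHAT])

# Chat + vision fetch, as used by the user-facing listing endpoints.
all_providers = fetch_existing_llm_providers(
    db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)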
@@ -1,10 +1,13 @@
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar

from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator

from onyx.db.enums import LLMModelFlowType
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import litellm_thinks_model_supports_image_input
from onyx.llm.utils import model_is_reasoning_model
@@ -20,22 +23,22 @@ if TYPE_CHECKING:
    ModelConfiguration as ModelConfigurationModel,
)

T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")


class TestLLMRequest(BaseModel):
    # provider level
    name: str | None = None
    provider: str
    model: str
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    custom_config: dict[str, str] | None = None

    # model level
    default_model_name: str
    deployment_name: str | None = None

    model_configurations: list["ModelConfigurationUpsertRequest"]

    # if try and use the existing API/custom config key
    api_key_changed: bool
    custom_config_changed: bool
@@ -51,13 +54,10 @@ class LLMProviderDescriptor(BaseModel):
    """A descriptor for an LLM provider that can be safely viewed by
    non-admin users. Used when giving a list of available LLMs."""

    id: int
    name: str
    provider: str
    provider_display_name: str  # Human-friendly name like "Claude (Anthropic)"
    default_model_name: str
    is_default_provider: bool | None
    is_default_vision_provider: bool | None
    default_vision_model: str | None
    model_configurations: list["ModelConfigurationView"]

    @classmethod
@@ -72,13 +72,10 @@ class LLMProviderDescriptor(BaseModel):
        provider = llm_provider_model.provider

        return cls(
            id=llm_provider_model.id,
            name=llm_provider_model.name,
            provider=provider,
            provider_display_name=get_provider_display_name(provider),
            default_model_name=llm_provider_model.default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            default_vision_model=llm_provider_model.default_vision_model,
            model_configurations=filter_model_configurations(
                llm_provider_model.model_configurations, provider
            ),
@@ -92,13 +89,11 @@ class LLMProvider(BaseModel):
    api_base: str | None = None
    api_version: str | None = None
    custom_config: dict[str, str] | None = None
    default_model_name: str
    is_public: bool = True
    is_auto_mode: bool = False
    groups: list[int] = Field(default_factory=list)
    personas: list[int] = Field(default_factory=list)
    deployment_name: str | None = None
    default_vision_model: str | None = None


class LLMProviderUpsertRequest(LLMProvider):
@@ -119,8 +114,6 @@ class LLMProviderView(LLMProvider):
    """Stripped down representation of LLMProvider for display / limited access info only"""

    id: int
    is_default_provider: bool | None = None
    is_default_vision_provider: bool | None = None
    model_configurations: list["ModelConfigurationView"]

    @classmethod
@@ -150,10 +143,6 @@ class LLMProviderView(LLMProvider):
            api_base=llm_provider_model.api_base,
            api_version=llm_provider_model.api_version,
            custom_config=llm_provider_model.custom_config,
            default_model_name=llm_provider_model.default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            default_vision_model=llm_provider_model.default_vision_model,
            is_public=llm_provider_model.is_public,
            is_auto_mode=llm_provider_model.is_auto_mode,
            groups=groups,
@@ -180,7 +169,8 @@ class ModelConfigurationUpsertRequest(BaseModel):
            name=model_configuration_model.name,
            is_visible=model_configuration_model.is_visible,
            max_input_tokens=model_configuration_model.max_input_tokens,
            supports_image_input=model_configuration_model.supports_image_input,
            supports_image_input=LLMModelFlowType.VISION
            in model_configuration_model.llm_model_flow_types,
            display_name=model_configuration_model.display_name,
        )
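The upsert-request hunk above replaces the stored supports_image_input flag with flow-type membership. A one-line sketch of the derivation, assuming llm_model_flow_types is a collection of LLMModelFlowType values on the ORM model:

from onyx.db.enums import LLMModelFlowType

def model_supports_image_input(flow_types: list[LLMModelFlowType]) -> bool:
    # Vision capability is now just "is this model mapped to the VISION flow".
    return LLMModelFlowType.VISION in flow_types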
@@ -219,7 +209,8 @@ class ModelConfigurationView(BaseModel):
            is_visible=model_configuration_model.is_visible,
            max_input_tokens=model_configuration_model.max_input_tokens,
            supports_image_input=(
                model_configuration_model.supports_image_input or False
                LLMModelFlowType.VISION
                in model_configuration_model.llm_model_flow_types
            ),
            # Infer reasoning support from model name/display name
            supports_reasoning=is_reasoning_model(
@@ -261,8 +252,9 @@ class ModelConfigurationView(BaseModel):
                )
            ),
            supports_image_input=(
                val
                if (val := model_configuration_model.supports_image_input) is not None
                True
                if LLMModelFlowType.VISION
                in model_configuration_model.llm_model_flow_types
                else litellm_thinks_model_supports_image_input(
                    model_configuration_model.name, provider_name
                )
@@ -371,3 +363,27 @@ class OpenRouterFinalModelResponse(BaseModel):
        int | None
    )  # From OpenRouter API context_length (may be missing for some models)
    supports_image_input: bool


class DefaultModel(BaseModel):
    provider_id: int
    model_name: str


class LLMProviderResponse(BaseModel, Generic[T]):
    providers: list[T]
    default_text: DefaultModel | None = None
    default_vision: DefaultModel | None = None

    @classmethod
    def from_models(
        cls,
        providers: list[T],
        default_text: DefaultModel | None = None,
        default_vision: DefaultModel | None = None,
    ) -> "LLMProviderResponse[T]":
        return cls(
            providers=providers,
            default_text=default_text,
            default_vision=default_vision,
        )
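A sketch of how a handler assembles the new generic envelope defined above; the provider id and model name here are illustrative values, not taken from this diff:

default_text = DefaultModel(provider_id=1, model_name="gpt-4o-mini")  # illustrative
response = LLMProviderResponse[LLMProviderDescriptor].from_models(
    providers=accessible_providers,  # list[LLMProviderDescriptor] built earlier
    default_text=default_text,
    default_vision=None,
)
# Roughly serializes as:
# {"providers": [...],
#  "default_text": {"provider_id": 1, "model_name": "gpt-4o-mini"},
#  "default_vision": null}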
@@ -23,7 +23,7 @@ from onyx.db.document import check_docs_exist
from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -286,7 +286,11 @@ def setup_postgres(db_session: Session) -> None:
    create_initial_default_connector(db_session)
    associate_default_cc_pair(db_session)

    if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
    if (
        GEN_AI_API_KEY
        and fetch_default_llm_model(db_session) is None
        and not INTEGRATION_TESTS_MODE
    ):
        # Only for dev flows
        logger.notice("Setting up default OpenAI LLM for dev.")

@@ -298,7 +302,6 @@ def setup_postgres(db_session: Session) -> None:
            api_base=None,
            api_version=None,
            custom_config=None,
            default_model_name=llm_model,
            is_public=True,
            groups=[],
            model_configurations=[
@@ -310,7 +313,9 @@ def setup_postgres(db_session: Session) -> None:
        new_llm_provider = upsert_llm_provider(
            llm_provider_upsert_request=model_req, db_session=db_session
        )
        update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
        update_default_provider(
            provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
        )


def update_default_multipass_indexing(db_session: Session) -> None:
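update_default_provider now takes the default model name as well as the provider id. Both forms below appear in this diff and are equivalent sketches of the new signature:

# Keyword form (as in setup_postgres above)
update_default_provider(
    provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)

# Positional form (as in the tests that follow)
update_default_provider(provider.id, "gpt-4o-mini", db_session)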
@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
    # Prepare the test request payload
    test_request: dict[str, Any] = {
        "provider": LlmProviderNames.BEDROCK,
        "default_model_name": _DEFAULT_BEDROCK_MODEL,
        "model": _DEFAULT_BEDROCK_MODEL,
        "api_key": None,
        "api_base": None,
        "api_version": None,
@@ -26,7 +26,6 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
            "AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
            "AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
        },
        "model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
        "api_key_changed": True,
    }

@@ -43,7 +42,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
    # Prepare the test request payload with invalid credentials
    test_request: dict[str, Any] = {
        "provider": LlmProviderNames.BEDROCK,
        "default_model_name": _DEFAULT_BEDROCK_MODEL,
        "model": _DEFAULT_BEDROCK_MODEL,
        "api_key": None,
        "api_base": None,
        "api_version": None,
@@ -52,7 +51,6 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
            "AWS_ACCESS_KEY_ID": "invalid_access_key_id",
            "AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
        },
        "model_configurations": [{"name": _DEFAULT_BEDROCK_MODEL, "is_visible": True}],
        "api_key_changed": True,
    }
@@ -12,6 +12,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest


# Counter for generating unique file IDs in mock file store
@@ -27,14 +28,19 @@ def ensure_default_llm_provider(db_session: Session) -> None:
            provider=LlmProviderNames.OPENAI,
            api_key=os.environ.get("OPENAI_API_KEY", "test"),
            is_public=True,
            default_model_name="gpt-4o-mini",
            model_configurations=[
                ModelConfigurationUpsertRequest(
                    name="gpt-4o-mini",
                    is_visible=True,
                )
            ],
            groups=[],
        )
        provider = upsert_llm_provider(
            llm_provider_upsert_request=llm_provider_request,
            db_session=db_session,
        )
        update_default_provider(provider.id, db_session)
        update_default_provider(provider.id, "gpt-4o-mini", db_session)
    except Exception as exc:  # pragma: no cover - only hits on duplicate setup issues
        print(f"Note: Could not create LLM provider: {exc}")
@@ -10,6 +10,7 @@ from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -36,7 +37,7 @@ def test_answer_with_only_anthropic_provider(
    assert anthropic_api_key, "ANTHROPIC_API_KEY environment variable must be set"

    # Drop any existing providers so that only Anthropic is available.
    for provider in fetch_existing_llm_providers(db_session):
    for provider in fetch_existing_llm_providers(db_session, [LLMModelFlowType.CHAT]):
        remove_llm_provider(db_session, provider.id)

    anthropic_model = "claude-haiku-4-5-20251001"
@@ -47,7 +48,6 @@ def test_answer_with_only_anthropic_provider(
        name=provider_name,
        provider=LlmProviderNames.ANTHROPIC,
        api_key=anthropic_api_key,
        default_model_name=anthropic_model,
        is_public=True,
        groups=[],
        model_configurations=[
@@ -59,7 +59,7 @@ def test_answer_with_only_anthropic_provider(
    )

    try:
        update_default_provider(anthropic_provider.id, db_session)
        update_default_provider(anthropic_provider.id, anthropic_model, db_session)

        test_user = create_test_user(db_session, email_prefix="anthropic_only")
        chat_session = create_chat_session(
@@ -14,6 +14,7 @@ import pytest
from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -106,12 +107,7 @@ class TestLLMConfigurationEndpoint:
                api_key="sk-new-test-key-0000000000000000000000000000",
                api_key_changed=True,
                custom_config_changed=False,
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
                    )
                ],
                model="gpt-4o-mini",
            ),
            _=_create_mock_admin(),
            db_session=db_session,
@@ -156,12 +152,7 @@ class TestLLMConfigurationEndpoint:
                api_key="sk-invalid-key-00000000000000000000000000",
                api_key_changed=True,
                custom_config_changed=False,
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
                    )
                ],
                model="gpt-4o-mini",
            ),
            _=_create_mock_admin(),
            db_session=db_session,
@@ -206,12 +197,7 @@ class TestLLMConfigurationEndpoint:
                api_key=None,  # Not providing a new key
                api_key_changed=False,  # Using existing key
                custom_config_changed=False,
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
                    )
                ],
                model="gpt-4o-mini",
            ),
            _=_create_mock_admin(),
            db_session=db_session,
@@ -258,12 +244,7 @@ class TestLLMConfigurationEndpoint:
                api_key=new_api_key,  # Providing a new key
                api_key_changed=True,  # Key is being changed
                custom_config_changed=False,
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
                    )
                ],
                model="gpt-4o-mini",
            ),
            _=_create_mock_admin(),
            db_session=db_session,
@@ -304,12 +285,7 @@ class TestLLMConfigurationEndpoint:
                api_key_changed=True,
                custom_config=original_custom_config,
                custom_config_changed=True,
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
                    )
                ],
                model="gpt-4o-mini",
            ),
            db_session=db_session,
        )
@@ -326,12 +302,7 @@ class TestLLMConfigurationEndpoint:
                api_key_changed=False,
                custom_config=None,  # Not providing new config
                custom_config_changed=False,  # Using existing config
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
                    )
                ],
                model="gpt-4o-mini",
            ),
            _=_create_mock_admin(),
            db_session=db_session,
@@ -372,12 +343,7 @@ class TestLLMConfigurationEndpoint:
                api_key="sk-test-key-00000000000000000000000000000000000",
                api_key_changed=True,
                custom_config_changed=False,
                default_model_name=model_name,
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name=model_name, is_visible=True
                    )
                ],
                model=model_name,
            ),
            _=_create_mock_admin(),
            db_session=db_session,
@@ -441,7 +407,6 @@ class TestDefaultProviderEndpoint:
                provider=LlmProviderNames.OPENAI,
                api_key=provider_1_api_key,
                api_key_changed=True,
                default_model_name=provider_1_initial_model,
                model_configurations=[
                    ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
                    ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -451,7 +416,7 @@ class TestDefaultProviderEndpoint:
        )

        # Set provider 1 as the default provider explicitly
        update_default_provider(provider_1.id, db_session)
        update_default_provider(provider_1.id, provider_1_initial_model, db_session)

        # Step 2: Call run_test_default_provider - should use provider 1's default model
        with patch(
@@ -471,7 +436,6 @@ class TestDefaultProviderEndpoint:
                provider=LlmProviderNames.OPENAI,
                api_key=provider_2_api_key,
                api_key_changed=True,
                default_model_name=provider_2_default_model,
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
@@ -502,7 +466,6 @@ class TestDefaultProviderEndpoint:
                provider=LlmProviderNames.OPENAI,
                api_key=provider_1_api_key,
                api_key_changed=True,
                default_model_name=provider_1_updated_model,  # Changed
                model_configurations=[
                    ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
                    ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -511,6 +474,9 @@ class TestDefaultProviderEndpoint:
            db_session=db_session,
        )

        # Set provider 1's default model to the updated model
        update_default_provider(provider_1.id, provider_1_updated_model, db_session)

        # Step 6: Call run_test_default_provider - should use new model on provider 1
        with patch(
            "onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -523,7 +489,7 @@ class TestDefaultProviderEndpoint:
        captured_llms.clear()

        # Step 7: Change the default provider to provider 2
        update_default_provider(provider_2.id, db_session)
        update_default_provider(provider_2.id, provider_2_default_model, db_session)

        # Step 8: Call run_test_default_provider - should use provider 2
        with patch(
@@ -551,7 +517,9 @@ class TestDefaultProviderEndpoint:
        from onyx.db.llm import fetch_existing_llm_providers

        try:
            existing_providers = fetch_existing_llm_providers(db_session)
            existing_providers = fetch_existing_llm_providers(
                db_session, flow_types=[LLMModelFlowType.CHAT]
            )
            provider_names_to_restore: list[str] = []

            for provider in existing_providers:
@@ -593,7 +561,6 @@ class TestDefaultProviderEndpoint:
                provider=LlmProviderNames.OPENAI,
                api_key="sk-test-key-00000000000000000000000000000000000",
                api_key_changed=True,
                default_model_name="gpt-4o-mini",
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o-mini", is_visible=True
@@ -602,7 +569,7 @@ class TestDefaultProviderEndpoint:
                ),
                db_session=db_session,
            )
            update_default_provider(provider.id, db_session)
            update_default_provider(provider.id, "gpt-4o-mini", db_session)

            # Test should fail
            with patch(
@@ -500,7 +500,7 @@ def test_upload_with_custom_config_then_change(
        LLMTestRequest(
            name=name,
            provider=provider_name,
            default_model_name=default_model_name,
            model=default_model_name,
            model_configurations=[
                ModelConfigurationUpsertRequest(
                    name=default_model_name, is_visible=True
@@ -539,7 +539,7 @@ test_upload_with_custom_config_then_change(
        LLMTestRequest(
            name=name,
            provider=provider_name,
            default_model_name=default_model_name,
            model=default_model_name,
            model_configurations=[
                ModelConfigurationUpsertRequest(
                    name=default_model_name, is_visible=True
@@ -14,7 +14,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session

from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -132,7 +132,6 @@ class TestAutoModeSyncFeature:
                api_key="sk-test-key-00000000000000000000000000000000000",
                api_key_changed=True,
                is_auto_mode=True,
                default_model_name=expected_default_model,
                model_configurations=[],  # No model configs provided
            ),
            is_creation=True,
@@ -160,25 +159,20 @@ class TestAutoModeSyncFeature:
                if mc.name in all_expected_models:
                    assert mc.is_visible is True, f"Model '{mc.name}' should be visible"

            # Verify the default model was set correctly
            assert (
                provider.default_model_name == expected_default_model
            ), f"Default model should be '{expected_default_model}'"

            # Step 4: Set the provider as default
            update_default_provider(provider.id, db_session)
            update_default_provider(provider.id, expected_default_model, db_session)

            # Step 5: Fetch the default provider and verify
            default_provider = fetch_default_provider(db_session)
            assert default_provider is not None, "Default provider should exist"
            default_model = fetch_default_llm_model(db_session)
            assert default_model is not None, "Default provider should exist"
            assert (
                default_provider.name == provider_name
                default_model.llm_provider.name == provider_name
            ), "Default provider should be our test provider"
            assert (
                default_provider.default_model_name == expected_default_model
                default_model.name == expected_default_model
            ), f"Default provider's default model should be '{expected_default_model}'"
            assert (
                default_provider.is_auto_mode is True
                default_model.llm_provider.is_auto_mode is True
            ), "Default provider should be in auto mode"

        finally:
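fetch_default_llm_model replaces fetch_default_provider and returns the default model configuration rather than the provider; provider-level fields are reached through its llm_provider relationship, exactly as the assertions above exercise. A sketch, assuming a live db_session:

default_model = fetch_default_llm_model(db_session)
if default_model is not None:
    model_name = default_model.name                      # e.g. "gpt-4o"
    provider_name = default_model.llm_provider.name      # the owning provider
    auto_mode = default_model.llm_provider.is_auto_mode  # provider-level flags still reachable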
@@ -235,7 +229,6 @@ class TestAutoModeSyncFeature:
                api_key="sk-test-key-00000000000000000000000000000000000",
                api_key_changed=True,
                is_auto_mode=True,
                default_model_name="gpt-4o",
                model_configurations=[],
            ),
            is_creation=True,
@@ -314,7 +307,6 @@ class TestAutoModeSyncFeature:
                api_key="sk-test-key-00000000000000000000000000000000000",
                api_key_changed=True,
                is_auto_mode=False,  # Not in auto mode initially
                default_model_name="gpt-4",
                model_configurations=initial_models,
            ),
            is_creation=True,
@@ -346,7 +338,6 @@ class TestAutoModeSyncFeature:
                api_key=None,  # Not changing API key
                api_key_changed=False,
                is_auto_mode=True,  # Now enabling auto mode
                default_model_name=auto_mode_default,
                model_configurations=[],  # Auto mode will sync from config
            ),
            is_creation=False,  # This is an update
@@ -385,9 +376,6 @@ class TestAutoModeSyncFeature:
                    model_visibility[model_name] is False
                ), f"Model '{model_name}' not in auto config should NOT be visible"

            # Verify the default model was updated
            assert provider.default_model_name == auto_mode_default

        finally:
            db_session.rollback()
            _cleanup_provider(db_session, provider_name)
@@ -429,7 +417,6 @@ class TestAutoModeSyncFeature:
                api_key="sk-test-key-00000000000000000000000000000000000",
                api_key_changed=True,
                is_auto_mode=True,
                default_model_name="gpt-4o",
                model_configurations=[],
            ),
            is_creation=True,
@@ -532,7 +519,6 @@ class TestAutoModeSyncFeature:
                api_key=provider_1_api_key,
                api_key_changed=True,
                is_auto_mode=True,
                default_model_name=provider_1_default_model,
                model_configurations=[],
            ),
            is_creation=True,
@@ -546,7 +532,7 @@ class TestAutoModeSyncFeature:
            name=provider_1_name, db_session=db_session
        )
        assert provider_1 is not None
        update_default_provider(provider_1.id, db_session)
        update_default_provider(provider_1.id, provider_1_default_model, db_session)

        with patch(
            "onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -560,7 +546,6 @@ class TestAutoModeSyncFeature:
                api_key=provider_2_api_key,
                api_key_changed=True,
                is_auto_mode=True,
                default_model_name=provider_2_default_model,
                model_configurations=[],
            ),
            is_creation=True,
@@ -570,26 +555,26 @@ class TestAutoModeSyncFeature:

        # Step 3: Verify provider 1 is still the default
        db_session.expire_all()
        default_provider = fetch_default_provider(db_session)
        assert default_provider is not None
        assert default_provider.name == provider_1_name
        assert default_provider.default_model_name == provider_1_default_model
        assert default_provider.is_auto_mode is True
        default_model = fetch_default_llm_model(db_session)
        assert default_model is not None
        assert default_model.llm_provider.name == provider_1_name
        assert default_model.name == provider_1_default_model
        assert default_model.llm_provider.is_auto_mode is True

        # Step 4: Change the default to provider 2
        provider_2 = fetch_existing_llm_provider(
            name=provider_2_name, db_session=db_session
        )
        assert provider_2 is not None
        update_default_provider(provider_2.id, db_session)
        update_default_provider(provider_2.id, provider_2_default_model, db_session)

        # Step 5: Verify provider 2 is now the default
        db_session.expire_all()
        default_provider = fetch_default_provider(db_session)
        assert default_provider is not None
        assert default_provider.name == provider_2_name
        assert default_provider.default_model_name == provider_2_default_model
        assert default_provider.is_auto_mode is True
        default_model = fetch_default_llm_model(db_session)
        assert default_model is not None
        assert default_model.llm_provider.name == provider_2_name
        assert default_model.name == provider_2_default_model
        assert default_model.llm_provider.is_auto_mode is True

        # Step 6: Run test_default_provider and verify it uses provider 2's model
        with patch(
@@ -5,6 +5,11 @@ from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4

from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest

# Set environment variables to disable model server for testing
os.environ["DISABLE_MODEL_SERVER"] = "true"
os.environ["MODEL_SERVER_HOST"] = "disabled"
@@ -403,9 +408,13 @@ class TestSlackBotFederatedSearch:
    def _setup_llm_provider(self, db_session: Session) -> None:
        """Create a default LLM provider in the database for testing with real API key"""
        # Delete any existing default LLM provider to ensure clean state
        # Use SQL-level delete to properly trigger ON DELETE CASCADE
        # (ORM-level delete tries to set foreign keys to NULL instead)
        from sqlalchemy import delete

        existing_providers = db_session.query(LLMProvider).all()
        for provider in existing_providers:
            db_session.delete(provider)
            db_session.execute(delete(LLMProvider).where(LLMProvider.id == provider.id))
        db_session.commit()

        api_key = os.getenv("OPENAI_API_KEY")
@@ -414,16 +423,25 @@ class TestSlackBotFederatedSearch:
                "OPENAI_API_KEY environment variable not set - test requires real API key"
            )

        llm_provider = LLMProvider(
            name=f"test-llm-provider-{uuid4().hex[:8]}",
            provider=LlmProviderNames.OPENAI,
            api_key=api_key,
            default_model_name="gpt-4o",
            is_default_provider=True,
            is_public=True,
        provider_view = upsert_llm_provider(
            LLMProviderUpsertRequest(
                name=f"test-llm-provider-{uuid4().hex[:8]}",
                provider=LlmProviderNames.OPENAI,
                api_key=api_key,
                is_public=True,
                model_configurations=[
                    ModelConfigurationUpsertRequest(
                        name="gpt-4o",
                        is_visible=True,
                        max_input_tokens=None,
                        display_name="gpt-4o",
                    ),
                ],
            ),
            db_session=db_session,
        )
        db_session.add(llm_provider)
        db_session.commit()

        update_default_provider(provider_view.id, "gpt-4o", db_session)

    def _teardown_common_mocks(self, patches: list) -> None:
        """Stop all patches"""
@@ -4,8 +4,10 @@ from uuid import uuid4
import requests

from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
@@ -36,7 +38,6 @@ class LLMProviderManager:
        llm_provider = LLMProviderUpsertRequest(
            name=name or f"test-provider-{uuid4()}",
            provider=provider or LlmProviderNames.OPENAI,
            default_model_name=default_model_name or "gpt-4o-mini",
            api_key=api_key or os.environ["OPENAI_API_KEY"],
            api_base=api_base,
            api_version=api_version,
@@ -44,7 +45,15 @@ class LLMProviderManager:
            is_public=True if is_public is None else is_public,
            groups=groups or [],
            personas=personas or [],
            model_configurations=[],
            model_configurations=[
                ModelConfigurationUpsertRequest(
                    name=default_model_name or "gpt-4o-mini",
                    is_visible=True,
                    max_input_tokens=None,
                    display_name=default_model_name or "gpt-4o-mini",
                    supports_image_input=True,
                )
            ],
            api_key_changed=True,
        )

@@ -125,7 +134,12 @@ class LLMProviderManager:
        user_performing_action: DATestUser | None = None,
    ) -> None:
        all_llm_providers = LLMProviderManager.get_all(user_performing_action)
        default_model = LLMProviderManager.get_default_model(user_performing_action)
        for fetched_llm_provider in all_llm_providers:
            model_names = [
                model.name for model in fetched_llm_provider.model_configurations
            ]

            if llm_provider.id == fetched_llm_provider.id:
                if verify_deleted:
                    raise ValueError(
@@ -138,11 +152,25 @@ class LLMProviderManager:
                if (
                    fetched_llm_groups == llm_provider_groups
                    and llm_provider.provider == fetched_llm_provider.provider
                    and llm_provider.default_model_name
                    == fetched_llm_provider.default_model_name
                    and default_model.model_name in model_names
                    and llm_provider.is_public == fetched_llm_provider.is_public
                    and set(fetched_llm_provider.personas) == set(llm_provider.personas)
                ):
                    return
        if not verify_deleted:
            raise ValueError(f"LLM Provider {llm_provider.id} not found")

    @staticmethod
    def get_default_model(
        user_performing_action: DATestUser | None = None,
    ) -> DefaultModel:
        response = requests.get(
            f"{API_SERVER_URL}/admin/llm/default",
            headers=(
                user_performing_action.headers
                if user_performing_action
                else GENERAL_HEADERS
            ),
        )
        response.raise_for_status()
        return DefaultModel(**response.json())
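The helper above pins down the new read endpoint: GET /admin/llm/default returns a DefaultModel payload. The same call outside the manager, as a sketch; the URL and response shape come from the helper, and the headers are whatever authenticated headers your client uses:

import requests

resp = requests.get(f"{API_SERVER_URL}/admin/llm/default", headers=GENERAL_HEADERS)
resp.raise_for_status()
default = DefaultModel(**resp.json())  # fields: provider_id, model_name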
File diff suppressed because it is too large.
@@ -79,6 +79,7 @@ export { default as SvgHourglass } from "@opal/icons/hourglass";
export { default as SvgImage } from "@opal/icons/image";
export { default as SvgImageSmall } from "@opal/icons/image-small";
export { default as SvgImport } from "@opal/icons/import";
export { default as SvgInfo } from "@opal/icons/info";
export { default as SvgInfoSmall } from "@opal/icons/info-small";
export { default as SvgKey } from "@opal/icons/key";
export { default as SvgKeystroke } from "@opal/icons/keystroke";
web/lib/opal/src/icons/info.tsx (new file, 20 lines)
@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgInfo = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M8.00001 10.6666V7.99998M8.00001 5.33331H8.00668M14.6667 7.99998C14.6667 11.6819 11.6819 14.6666 8.00001 14.6666C4.31811 14.6666 1.33334 11.6819 1.33334 7.99998C1.33334 4.31808 4.31811 1.33331 8.00001 1.33331C11.6819 1.33331 14.6667 4.31808 14.6667 7.99998Z"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgInfo;
@@ -6,7 +6,7 @@ import { Callout } from "@/components/ui/callout";
import Text from "@/refresh-components/texts/Text";
import Title from "@/components/ui/title";
import { ThreeDotsLoader } from "@/components/Loading";
import { LLMProviderView } from "./interfaces";
import { LLMProviderResponse, LLMProviderView } from "./interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { OpenAIForm } from "./forms/OpenAIForm";
import { AnthropicForm } from "./forms/AnthropicForm";
@@ -17,44 +17,59 @@ import { VertexAIForm } from "./forms/VertexAIForm";
import { OpenRouterForm } from "./forms/OpenRouterForm";
import { getFormForExistingProvider } from "./forms/getForm";
import { CustomForm } from "./forms/CustomForm";
import { DefaultModelSelector } from "./forms/components/DefaultModel";
import * as GeneralLayouts from "@/layouts/general-layouts";
import Separator from "@/refresh-components/Separator";
import { useLlmManager } from "@/lib/hooks";

export function LLMConfiguration() {
  const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
    LLM_PROVIDERS_ADMIN_URL,
    errorHandlingFetcher
  );
  const { data: existingLLMProvidersResponse } = useSWR<
    LLMProviderResponse<LLMProviderView>
  >(LLM_PROVIDERS_ADMIN_URL, errorHandlingFetcher);

  if (!existingLlmProviders) {
  if (!existingLLMProvidersResponse) {
    return <ThreeDotsLoader />;
  }

  const existingLlmProviders = existingLLMProvidersResponse.providers;
  const defaultProviderId =
    existingLLMProvidersResponse.default_text?.provider_id;
  const defaultLlmModel =
    existingLLMProvidersResponse.default_text ?? undefined;
  const isFirstProvider = existingLlmProviders.length === 0;

  const { updateDefaultLlmModel } = useLlmManager();

  return (
    <>
      <Title className="mb-2">Enabled LLM Providers</Title>

      {existingLlmProviders.length > 0 ? (
        <>
          <Text as="p" className="mb-4">
            If multiple LLM providers are enabled, the default provider will be
            used for all "Default" Assistants. For user-created
            Assistants, you can select the LLM provider/model that best fits the
            use case!
          </Text>
          <div className="flex flex-col gap-y-4">
            <DefaultModelSelector
              existingLlmProviders={existingLlmProviders}
              defaultLlmModel={defaultLlmModel ?? null}
              onModelChange={(provider_id, model_name) =>
                updateDefaultLlmModel({ provider_id, model_name })
              }
            />

          <GeneralLayouts.Section
            flexDirection="column"
            justifyContent="start"
            alignItems="start"
          >
            <Text headingH3>Available Providers</Text>
            {[...existingLlmProviders]
              .sort((a, b) => {
                if (a.is_default_provider && !b.is_default_provider) return -1;
                if (!a.is_default_provider && b.is_default_provider) return 1;
                if (a.id === defaultProviderId && b.id !== defaultProviderId)
                  return -1;
                if (a.id !== defaultProviderId && b.id === defaultProviderId)
                  return 1;
                return 0;
              })
              .map((llmProvider) => (
                <div key={llmProvider.id}>
                  {getFormForExistingProvider(llmProvider)}
                </div>
              ))}
          </div>
              .map((llmProvider) => getFormForExistingProvider(llmProvider))}
          </GeneralLayouts.Section>

          <Separator />
        </>
      ) : (
        <Callout type="warning" title="No LLM providers configured yet">
@@ -62,19 +77,25 @@ export function LLMConfiguration() {
        </Callout>
      )}

      <Title className="mb-2 mt-6">Add LLM Provider</Title>
      <Text as="p" className="mb-4">
        Add a new LLM provider by either selecting from one of the default
        providers or by specifying your own custom LLM provider.
      </Text>
      <GeneralLayouts.Section
        flexDirection="column"
        justifyContent="start"
        alignItems="start"
        gap={0}
      >
        <Text headingH3>Add Provider</Text>
        <Text as="p" secondaryBody text03>
          Onyx supports both popular providers and self-hosted models.
        </Text>
      </GeneralLayouts.Section>

      <div className="flex flex-col gap-y-4">
        <div className="grid grid-cols-2 gap-4">
          <OpenAIForm shouldMarkAsDefault={isFirstProvider} />
          <AnthropicForm shouldMarkAsDefault={isFirstProvider} />
          <OllamaForm shouldMarkAsDefault={isFirstProvider} />
          <VertexAIForm shouldMarkAsDefault={isFirstProvider} />
          <AzureForm shouldMarkAsDefault={isFirstProvider} />
          <BedrockForm shouldMarkAsDefault={isFirstProvider} />
          <VertexAIForm shouldMarkAsDefault={isFirstProvider} />
          <OpenRouterForm shouldMarkAsDefault={isFirstProvider} />

          <CustomForm shouldMarkAsDefault={isFirstProvider} />
@@ -12,6 +12,7 @@ export const ProviderIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: ProviderIconProps) => {
  console.log(provider);
  const Icon = getProviderIcon(provider, modelName);
  return <Icon size={size} className={className} />;
};
@@ -1,4 +1,5 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const LLM_ADMIN_URL = "/api/admin/llm";
export const LLM_PROVIDERS_ADMIN_URL = `${LLM_ADMIN_URL}/provider`;

export const LLM_CONTEXTUAL_COST_ADMIN_URL =
  "/api/admin/llm/provider-contextual-cost";
@@ -17,17 +17,22 @@ import {
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import Separator from "@/refresh-components/Separator";
import InputWrapper from "./components/InputWrapper";

export const ANTHROPIC_PROVIDER_NAME = "anthropic";
const DEFAULT_DEFAULT_MODEL_NAME = "claude-sonnet-4-5";

export function AnthropicForm({
  existingLlmProvider,
  defaultLlmModel,
  shouldMarkAsDefault,
}: LLMProviderFormProps) {
  return (
    <ProviderFormEntrypointWrapper
      providerName="Anthropic"
      providerDisplayName={existingLlmProvider?.name ?? "Claude"}
      providerInternalName="anthropic"
      providerEndpoint={ANTHROPIC_PROVIDER_NAME}
      existingLlmProvider={existingLlmProvider}
    >
@@ -46,19 +51,28 @@ export function AnthropicForm({
          existingLlmProvider,
          wellKnownLLMProvider
        );

        const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
        const autoModelDefault =
          wellKnownLLMProvider?.recommended_default_model?.name ??
          DEFAULT_DEFAULT_MODEL_NAME;

        const defaultModel = shouldMarkAsDefault
          ? isAutoMode
            ? autoModelDefault
            : defaultLlmModel?.model_name ?? DEFAULT_DEFAULT_MODEL_NAME
          : undefined;

        const initialValues = {
          ...buildDefaultInitialValues(
            existingLlmProvider,
            modelConfigurations
            modelConfigurations,
            defaultModel
          ),
          api_key: existingLlmProvider?.api_key ?? "",
          api_base: existingLlmProvider?.api_base ?? "",
          default_model_name:
            existingLlmProvider?.default_model_name ??
            wellKnownLLMProvider?.recommended_default_model?.name ??
            DEFAULT_DEFAULT_MODEL_NAME,
          // Default to auto mode for new Anthropic providers
          is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
          is_auto_mode: isAutoMode,
        };

        const validationSchema = buildDefaultValidationSchema().shape({
@@ -92,16 +106,26 @@ export function AnthropicForm({
          {(formikProps) => {
            return (
              <Form className={LLM_FORM_CLASS_NAME}>
                <InputWrapper
                  label="API Key"
                  description="Paste your {link} from Anthropic to access your models."
                  descriptionLink={{
                    text: "API key",
                    href: "https://console.anthropic.com/dashboard",
                  }}
                >
                  <PasswordInputTypeInField name="api_key" />
                </InputWrapper>

                <Separator noPadding />

                <DisplayNameField disabled={!!existingLlmProvider} />

                <PasswordInputTypeInField name="api_key" label="API Key" />
                <Separator noPadding />

                <DisplayModels
                  modelConfigurations={modelConfigurations}
                  formikProps={formikProps}
                  recommendedDefaultModel={
                    wellKnownLLMProvider?.recommended_default_model ?? null
                  }
                  shouldShowAutoUpdateToggle={true}
                />

@@ -18,15 +18,16 @@ import {
  LLM_FORM_CLASS_NAME,
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { SingleDefaultModelField } from "./components/SingleDefaultModelField";
import {
  isValidAzureTargetUri,
  parseAzureTargetUri,
} from "@/lib/azureTargetUri";
import Separator from "@/refresh-components/Separator";
import { DisplayModels } from "./components/DisplayModels";
import InputWrapper from "./components/InputWrapper";

export const AZURE_PROVIDER_NAME = "azure";
const AZURE_DISPLAY_NAME = "Microsoft Azure Cloud";
const AZURE_DISPLAY_NAME = "Azure OpenAI";

interface AzureFormValues extends BaseLLMFormValues {
  api_key: string;
@@ -53,7 +54,9 @@ export function AzureForm({
}: LLMProviderFormProps) {
  return (
    <ProviderFormEntrypointWrapper
      providerName={AZURE_DISPLAY_NAME}
      providerName={"Microsoft Azure"}
      providerDisplayName={existingLlmProvider?.name ?? AZURE_DISPLAY_NAME}
      providerInternalName={"azure"}
      providerEndpoint={AZURE_PROVIDER_NAME}
      existingLlmProvider={existingLlmProvider}
    >
@@ -138,19 +141,41 @@ export function AzureForm({
          {(formikProps) => {
            return (
              <Form className={LLM_FORM_CLASS_NAME}>
                <DisplayNameField disabled={!!existingLlmProvider} />

                <PasswordInputTypeInField name="api_key" label="API Key" />

                <TextFormField
                  name="target_uri"
                <InputWrapper
                  label="Target URI"
                  placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
                  subtext="The complete target URI for your deployment from the Azure AI portal."
                />
                  description="Paste your endpoint target URI from Azure OpenAI (including API endpoint base, deployment name, and API version)."
                >
                  <TextFormField
                    name="target_uri"
                    label=""
                    placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
                  />
                </InputWrapper>

                <InputWrapper
                  label="API Key"
                  description="Paste your API key from {link} to access your models."
                  descriptionLink={{
                    text: "Azure OpenAI",
                    href: "https://oai.azure.com",
                  }}
                >
                  <PasswordInputTypeInField name="api_key" placeholder="" />
                </InputWrapper>

                <Separator />
                <SingleDefaultModelField placeholder="E.g. gpt-4o" />

                <DisplayNameField disabled={!!existingLlmProvider} />

                <Separator noPadding />

                <DisplayModels
                  modelConfigurations={modelConfigurations}
                  formikProps={formikProps}
                  shouldShowAutoUpdateToggle={false}
                  noModelConfigurationsMessage="No models available. Provide a valid base URL or key."
                />

                <Separator />

                <AdvancedOptions formikProps={formikProps} />
@@ -2,8 +2,10 @@

import { useState, useEffect } from "react";
import { Form, Formik, FormikProps } from "formik";
import { SelectorFormField, TextFormField } from "@/components/Field";
import { TextFormField } from "@/components/Field";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import {
  LLMProviderFormProps,
  LLMProviderView,
@@ -16,7 +18,6 @@ import {
} from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import { FormActionButtons } from "./components/FormActionButtons";
import { FetchModelsButton } from "./components/FetchModelsButton";
import {
  buildDefaultInitialValues,
  buildDefaultValidationSchema,
@@ -27,14 +28,16 @@
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { fetchBedrockModels } from "../utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import Tabs from "@/refresh-components/Tabs";
import { cn } from "@/lib/utils";
import { Card } from "@/refresh-components/cards";
import { SvgInfo } from "@opal/icons";
import * as GeneralLayouts from "@/layouts/general-layouts";
import InputWrapper from "./components/InputWrapper";

export const BEDROCK_PROVIDER_NAME = "bedrock";
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
const BEDROCK_DISPLAY_NAME = "Amazon Bedrock";

// AWS Bedrock regions - kept in sync with backend
const AWS_REGION_OPTIONS = [
@@ -58,6 +61,13 @@ const AUTH_METHOD_IAM = "iam";
const AUTH_METHOD_ACCESS_KEY = "access_key";
const AUTH_METHOD_LONG_TERM_API_KEY = "long_term_api_key";

// Auth method options
const AUTH_METHOD_OPTIONS = [
  { name: "IAM Role", value: AUTH_METHOD_IAM },
  { name: "Access Key", value: AUTH_METHOD_ACCESS_KEY },
  { name: "Long-term API Key", value: AUTH_METHOD_LONG_TERM_API_KEY },
];

// Field name constants
const FIELD_AWS_REGION_NAME = "custom_config.AWS_REGION_NAME";
const FIELD_BEDROCK_AUTH_METHOD = "custom_config.BEDROCK_AUTH_METHOD";
@@ -135,99 +145,96 @@ function BedrockFormInternals({
  const isFetchDisabled =
    !formikProps.values.custom_config?.AWS_REGION_NAME || !isAuthComplete;

  const authDisplay = () => {
    if (authMethod === AUTH_METHOD_IAM) {
      return (
        <>
          <Card variant="secondary">
            <GeneralLayouts.Section
              flexDirection="row"
              alignItems="center"
              justifyContent="start"
            >
              <SvgInfo className="w-4 h-4" />
              <Text as="p" text03>
                Onyx will use the IAM role attached to the environment it's
                running in to authenticate.
              </Text>
            </GeneralLayouts.Section>
          </Card>
        </>
      );
    } else if (authMethod === AUTH_METHOD_ACCESS_KEY) {
      return (
        <>
          <TextFormField
            name={FIELD_AWS_ACCESS_KEY_ID}
            label="Access Key ID"
            placeholder=""
          />
          <PasswordInputTypeInField
            name={FIELD_AWS_SECRET_ACCESS_KEY}
            label="Secret Access Key"
            placeholder=""
          />
        </>
      );
    } else if (authMethod === AUTH_METHOD_LONG_TERM_API_KEY) {
      return (
        <>
          <PasswordInputTypeInField
            name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
            label="Long-term API Key"
            placeholder=""
          />
        </>
      );
    }
  };

  return (
    <Form className={cn(LLM_FORM_CLASS_NAME, "w-full")}>
      <DisplayNameField disabled={!!existingLlmProvider} />

      <SelectorFormField
        name={FIELD_AWS_REGION_NAME}
        label="AWS Region"
        subtext="Region where your Amazon Bedrock models are hosted."
        options={AWS_REGION_OPTIONS}
      />

      <div>
        <Text as="p" mainUiAction>
          Authentication Method
        </Text>
      <InputSelectField name={FIELD_AWS_REGION_NAME}>
        <Text as="p">AWS Region Name</Text>
        <InputSelect.Trigger placeholder="Select AWS region" />
        <InputSelect.Content>
          {AWS_REGION_OPTIONS.map((option) => (
            <InputSelect.Item key={option.value} value={option.value}>
              {option.name}
            </InputSelect.Item>
          ))}
        </InputSelect.Content>
        <Text as="p" secondaryBody text03>
          Choose how Onyx should authenticate with Bedrock.
          Region where your Amazon Bedrock models are hosted. See full list of
          regions supported at AWS.
        </Text>
        <Tabs
          value={authMethod || AUTH_METHOD_ACCESS_KEY}
          onValueChange={(value) =>
            formikProps.setFieldValue(FIELD_BEDROCK_AUTH_METHOD, value)
          }
      </InputSelectField>

      <InputSelectField name={FIELD_BEDROCK_AUTH_METHOD}>
        <InputWrapper
          label="Authentication Method"
          description="See {link} for more instructions."
          descriptionLink={{
            text: "documentation",
            href: `https://docs.onyx.app/admin/ai_models/bedrock#authentication-methods`,
          }}
        >
          <Tabs.List>
            <Tabs.Trigger value={AUTH_METHOD_IAM}>IAM Role</Tabs.Trigger>
            <Tabs.Trigger value={AUTH_METHOD_ACCESS_KEY}>
              Access Key
            </Tabs.Trigger>
            <Tabs.Trigger value={AUTH_METHOD_LONG_TERM_API_KEY}>
              Long-term API Key
            </Tabs.Trigger>
          </Tabs.List>
          <InputSelect.Trigger placeholder="Select authentication method" />
          <InputSelect.Content>
            {AUTH_METHOD_OPTIONS.map((option) => (
              <InputSelect.Item key={option.value} value={option.value}>
                {option.name}
              </InputSelect.Item>
            ))}
          </InputSelect.Content>
        </InputWrapper>
      </InputSelectField>

      <Tabs.Content value={AUTH_METHOD_IAM}>
        <Text as="p" text03>
          Uses the IAM role attached to your AWS environment. Recommended
          for EC2, ECS, Lambda, or other AWS services.
        </Text>
      </Tabs.Content>
      {authDisplay()}

      <Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
        <div className="flex flex-col gap-4 w-full">
          <TextFormField
            name={FIELD_AWS_ACCESS_KEY_ID}
            label="AWS Access Key ID"
            placeholder="AKIAIOSFODNN7EXAMPLE"
          />
          <PasswordInputTypeInField
            name={FIELD_AWS_SECRET_ACCESS_KEY}
            label="AWS Secret Access Key"
            placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
          />
        </div>
      </Tabs.Content>
      <Separator noPadding />

      <Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
        <div className="flex flex-col gap-4 w-full">
          <PasswordInputTypeInField
            name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
            label="AWS Bedrock Long-term API Key"
            placeholder="Your long-term API key"
          />
        </div>
      </Tabs.Content>
      </Tabs>
      </div>

      <FetchModelsButton
        onFetch={() =>
          fetchBedrockModels({
            aws_region_name:
              formikProps.values.custom_config?.AWS_REGION_NAME ?? "",
            aws_access_key_id:
              formikProps.values.custom_config?.AWS_ACCESS_KEY_ID,
            aws_secret_access_key:
              formikProps.values.custom_config?.AWS_SECRET_ACCESS_KEY,
            aws_bearer_token_bedrock:
              formikProps.values.custom_config?.AWS_BEARER_TOKEN_BEDROCK,
            provider_name: existingLlmProvider?.name,
          })
        }
        isDisabled={isFetchDisabled}
        disabledHint={
          !formikProps.values.custom_config?.AWS_REGION_NAME
            ? "Select an AWS region."
            : !isAuthComplete
              ? 'Complete the "Authentication Method" section.'
              : undefined
        }
        onModelsFetched={setFetchedModels}
        autoFetchOnInitialLoad={!!existingLlmProvider}
      />
      <DisplayNameField />

      <Separator />
@@ -235,10 +242,8 @@ function BedrockFormInternals({
        modelConfigurations={currentModels}
        formikProps={formikProps}
        noModelConfigurationsMessage={
          "Fetch available models first, then you'll be able to select " +
          "the models you want to make available in Onyx."
          "No models available. Provide a valid region and key."
        }
        recommendedDefaultModel={null}
        shouldShowAutoUpdateToggle={false}
      />

@@ -266,7 +271,9 @@ export function BedrockForm({

  return (
    <ProviderFormEntrypointWrapper
      providerName={BEDROCK_DISPLAY_NAME}
      providerName={"AWS"}
      providerDisplayName={existingLlmProvider?.name ?? BEDROCK_DISPLAY_NAME}
      providerInternalName={"bedrock"}
      existingLlmProvider={existingLlmProvider}
    >
      {({
@@ -326,9 +326,13 @@ describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
// Verify set as default API was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/provider/5/default",
|
||||
"/api/admin/llm/provider/default",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
provider_id: 5,
|
||||
model_name: "gpt-4",
|
||||
}),
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
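// Note: per the updated test above, the default-model endpoint now takes
// provider_id and model_name in the POST body ("/api/admin/llm/provider/default")
// instead of encoding the provider id in the URL path.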
@@ -43,10 +43,9 @@ export function CustomForm({
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="Custom LLM"
providerName="LiteLLM"
providerDisplayName={existingLlmProvider?.name ?? "LiteLLM Models"}
existingLlmProvider={existingLlmProvider}
buttonMode={!existingLlmProvider}
buttonText="Add Custom LLM Provider"
>
{({
onClose,

@@ -25,10 +25,17 @@ import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { useEffect, useState } from "react";
import { fetchOllamaModels } from "../utils";
import Tabs from "@/refresh-components/Tabs";
import Separator from "@/refresh-components/Separator";

export const OLLAMA_PROVIDER_NAME = "ollama_chat";
const DEFAULT_API_BASE = "http://127.0.0.1:11434";

enum OllamaHostingTab {
SelfHosted = "self-hosted",
Cloud = "cloud",
}

interface OllamaFormValues extends BaseLLMFormValues {
api_base: string;
custom_config: {
@@ -60,6 +67,9 @@ function OllamaFormContent({
isFormValid,
}: OllamaFormContentProps) {
const [isLoadingModels, setIsLoadingModels] = useState(true);
const [activeTab, setActiveTab] = useState<OllamaHostingTab>(
OllamaHostingTab.SelfHosted
);

useEffect(() => {
if (formikProps.values.api_base) {
@@ -93,27 +103,48 @@ function OllamaFormContent({

return (
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />
<Tabs
value={activeTab}
onValueChange={(value) => setActiveTab(value as OllamaHostingTab)}
>
<Tabs.List>
<Tabs.Trigger value={OllamaHostingTab.SelfHosted}>
Self-hosted Ollama
</Tabs.Trigger>
<Tabs.Trigger value={OllamaHostingTab.Cloud}>
Ollama Cloud
</Tabs.Trigger>
</Tabs.List>

<TextFormField
name="api_base"
label="API Base URL"
subtext="The base URL for your Ollama instance (e.g., http://127.0.0.1:11434)"
placeholder={DEFAULT_API_BASE}
/>
<Tabs.Content value={OllamaHostingTab.SelfHosted}>
<TextFormField
name="api_base"
label="API Base URL"
subtext="Your self-hosted Ollama API base URL."
placeholder="Your Ollama API base URL."
/>
</Tabs.Content>

<PasswordInputTypeInField
name="custom_config.OLLAMA_API_KEY"
label="API Key (Optional)"
subtext="Optional API key for Ollama Cloud (https://ollama.com). Leave blank for local instances."
/>
<Tabs.Content value={OllamaHostingTab.Cloud}>
<TextFormField
name="api_key"
label="API Key"
subtext="Paste your API key from Ollama Cloud to access your models."
/>
</Tabs.Content>
</Tabs>

<Separator noPadding />

<DisplayNameField />

<Separator noPadding />

<DisplayModels
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage="No models found. Please provide a valid API base URL."
noModelConfigurationsMessage="No models found. Please provide a valid base URL or key."
isLoading={isLoadingModels}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>

@@ -140,6 +171,7 @@ export function OllamaForm({
return (
<ProviderFormEntrypointWrapper
providerName="Ollama"
providerDisplayName={existingLlmProvider?.name ?? "Ollama"}
existingLlmProvider={existingLlmProvider}
>
{({

@@ -1,31 +1,34 @@
import { Form, Formik } from "formik";
import * as Yup from "yup";

import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
import PasswordInputTypeInField from "@/refresh-components/form/PasswordInputTypeInField";
import { FormActionButtons } from "./components/FormActionButtons";
import {
buildDefaultInitialValues,
buildDefaultValidationSchema,
buildAvailableModelConfigurations,
submitLLMProvider,
LLM_FORM_CLASS_NAME,
} from "./formUtils";
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import LLMFormLayout from "./components/FormLayout";
import Separator from "@/refresh-components/Separator";
import InputWrapper from "./components/InputWrapper";

export const OPENAI_PROVIDER_NAME = "openai";
const DEFAULT_DEFAULT_MODEL_NAME = "gpt-5.2";

export function OpenAIForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName="OpenAI"
providerDisplayName={existingLlmProvider?.name ?? "OpenAI"}
providerEndpoint={OPENAI_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -44,18 +47,27 @@ export function OpenAIForm({
existingLlmProvider,
wellKnownLLMProvider
);
const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME;

// We use a default model if we're editing and this provider is the global default
// Or we are creating the first provider (and shouldMarkAsDefault is true)
const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name ?? DEFAULT_DEFAULT_MODEL_NAME
: undefined;
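// Illustration: when creating the first provider (shouldMarkAsDefault) with
// auto mode on, defaultModel resolves to the recommended model, falling back
// to DEFAULT_DEFAULT_MODEL_NAME; when editing a provider that is not the
// global default, defaultModel stays undefined and the existing default is
// left untouched.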

const initialValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
api_key: existingLlmProvider?.api_key ?? "",
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
DEFAULT_DEFAULT_MODEL_NAME,
// Default to auto mode for new OpenAI providers
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
is_auto_mode: isAutoMode,
};

const validationSchema = buildDefaultValidationSchema().shape({
@@ -86,35 +98,47 @@ export function OpenAIForm({
});
}}
>
{(formikProps) => {
return (
<Form className={LLM_FORM_CLASS_NAME}>
{(formikProps) => (
<Form>
<LLMFormLayout.Body>
<InputWrapper
label="API Key"
description="Paste your {link} from OpenAI to access your models."
descriptionLink={{
text: "API key",
href: "https://platform.openai.com/api-keys",
}}
>
<PasswordInputTypeInField
name="api_key"
subtext="Paste your API key from OpenAI to access your models."
/>
</InputWrapper>

<Separator />

<DisplayNameField disabled={!!existingLlmProvider} />

<PasswordInputTypeInField name="api_key" label="API Key" />
<Separator />

<DisplayModels
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>

<AdvancedOptions formikProps={formikProps} />
</LLMFormLayout.Body>

<FormActionButtons
isTesting={isTesting}
testError={testError}
existingLlmProvider={existingLlmProvider}
mutate={mutate}
onClose={onClose}
isFormValid={formikProps.isValid}
/>
</Form>
);
}}
<LLMFormLayout.Footer
onCancel={onClose}
submitLabel={existingLlmProvider ? "Update" : "Enable"}
isSubmitting={isTesting}
isSubmitDisabled={!formikProps.isValid}
error={testError}
/>
</Form>
)}
</Formik>
</>
);

@@ -26,6 +26,7 @@ import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import { FetchModelsButton } from "./components/FetchModelsButton";
import { useState } from "react";
import InputWrapper from "./components/InputWrapper";

export const OPENROUTER_PROVIDER_NAME = "openrouter";
const OPENROUTER_DISPLAY_NAME = "OpenRouter";
@@ -40,7 +41,6 @@ interface OpenRouterFormValues extends BaseLLMFormValues {
async function fetchOpenRouterModels(params: {
apiBase: string;
apiKey: string;
providerName?: string;
}): Promise<{ models: ModelConfiguration[]; error?: string }> {
if (!params.apiBase || !params.apiKey) {
return {
@@ -58,7 +58,6 @@ async function fetchOpenRouterModels(params: {
body: JSON.stringify({
api_base: params.apiBase,
api_key: params.apiKey,
provider_name: params.providerName,
}),
});

@@ -92,6 +91,7 @@ async function fetchOpenRouterModels(params: {

export function OpenRouterForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
const [fetchedModels, setFetchedModels] = useState<ModelConfiguration[]>([]);
@@ -99,6 +99,7 @@ export function OpenRouterForm({
return (
<ProviderFormEntrypointWrapper
providerName={OPENROUTER_DISPLAY_NAME}
providerDisplayName={existingLlmProvider?.name ?? OPENROUTER_DISPLAY_NAME}
providerEndpoint={OPENROUTER_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -117,10 +118,22 @@ export function OpenRouterForm({
existingLlmProvider,
wellKnownLLMProvider
);

const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name;

const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name
: undefined;

const initialValues: OpenRouterFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? DEFAULT_API_BASE,
@@ -170,36 +183,31 @@ export function OpenRouterForm({

return (
<Form className={LLM_FORM_CLASS_NAME}>
<DisplayNameField disabled={!!existingLlmProvider} />

<PasswordInputTypeInField name="api_key" label="API Key" />

<TextFormField
name="api_base"
<InputWrapper
label="API Base URL"
subtext="The base URL for OpenRouter API."
placeholder={DEFAULT_API_BASE}
/>
description="Paste your OpenRouter compatible endpoint URL or use OpenRouter API directly."
>
<TextFormField
name="api_base_url"
placeholder="https://openrouter.ai/api/v1"
label=""
/>
</InputWrapper>

<FetchModelsButton
onFetch={() =>
fetchOpenRouterModels({
apiBase: formikProps.values.api_base,
apiKey: formikProps.values.api_key,
providerName: existingLlmProvider?.name,
})
}
isDisabled={isFetchDisabled}
disabledHint={
!formikProps.values.api_key
? "Enter your API key first."
: !formikProps.values.api_base
? "Enter the API base URL."
: undefined
}
onModelsFetched={setFetchedModels}
autoFetchOnInitialLoad={!!existingLlmProvider}
/>
<InputWrapper
label="API Key"
description="Paste your API key from {link} to access your models."
descriptionLink={{
text: "OpenRouter",
href: "https://openrouter.ai/settings/keys",
}}
>
<PasswordInputTypeInField name="api_key" label="" />
</InputWrapper>

<Separator />

<DisplayNameField />

<Separator />

@@ -207,10 +215,8 @@ export function OpenRouterForm({
modelConfigurations={currentModels}
formikProps={formikProps}
noModelConfigurationsMessage={
"Fetch available models first, then you'll be able to select " +
"the models you want to make available in Onyx."
"No models available. Provide a valid base URL and key."
}
recommendedDefaultModel={null}
shouldShowAutoUpdateToggle={false}
/>


@@ -1,5 +1,4 @@
import { Form, Formik } from "formik";
import { TextFormField, FileUploadFormField } from "@/components/Field";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import {
@@ -19,12 +18,53 @@ import {
import { AdvancedOptions } from "./components/AdvancedOptions";
import { DisplayModels } from "./components/DisplayModels";
import Separator from "@/refresh-components/Separator";
import LLMFormLayout from "./components/FormLayout";
import { FormField } from "@/refresh-components/form/FormField";
import InputFile from "@/refresh-components/inputs/InputFile";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import Text from "@/refresh-components/texts/Text";
import { ModelAccessOptions } from "./components/ModelAccessOptions";

export const VERTEXAI_PROVIDER_NAME = "vertex_ai";
const VERTEXAI_DISPLAY_NAME = "Google Cloud Vertex AI";
const VERTEXAI_DEFAULT_MODEL = "gemini-2.5-pro";
const VERTEXAI_DEFAULT_LOCATION = "global";

const VERTEXAI_REGION_OPTIONS = [
{ name: "global", value: "global" },
{ name: "us-central1", value: "us-central1" },
{ name: "us-east1", value: "us-east1" },
{ name: "us-east4", value: "us-east4" },
{ name: "us-east5", value: "us-east5" },
{ name: "us-south1", value: "us-south1" },
{ name: "us-west1", value: "us-west1" },
{ name: "northamerica-northeast1", value: "northamerica-northeast1" },
{ name: "southamerica-east1", value: "southamerica-east1" },
{ name: "europe-west4", value: "europe-west4" },
{ name: "europe-west9", value: "europe-west9" },
{ name: "europe-west2", value: "europe-west2" },
{ name: "europe-west3", value: "europe-west3" },
{ name: "europe-west1", value: "europe-west1" },
{ name: "europe-west6", value: "europe-west6" },
{ name: "europe-southwest1", value: "europe-southwest1" },
{ name: "europe-west8", value: "europe-west8" },
{ name: "europe-north1", value: "europe-north1" },
{ name: "europe-central2", value: "europe-central2" },
{ name: "asia-northeast1", value: "asia-northeast1" },
{ name: "australia-southeast1", value: "australia-southeast1" },
{ name: "asia-southeast1", value: "asia-southeast1" },
{ name: "asia-northeast3", value: "asia-northeast3" },
{ name: "asia-east1", value: "asia-east1" },
{ name: "asia-east2", value: "asia-east2" },
{ name: "asia-south1", value: "asia-south1" },
{ name: "me-central2", value: "me-central2" },
{ name: "me-central1", value: "me-central1" },
{ name: "me-west1", value: "me-west1" },
];

const VERTEXAI_REGION_NAME = "custom_config.vertex_location";

interface VertexAIFormValues extends BaseLLMFormValues {
custom_config: {
vertex_credentials: string;
@@ -34,11 +74,14 @@ interface VertexAIFormValues extends BaseLLMFormValues {

export function VertexAIForm({
existingLlmProvider,
defaultLlmModel,
shouldMarkAsDefault,
}: LLMProviderFormProps) {
return (
<ProviderFormEntrypointWrapper
providerName={VERTEXAI_DISPLAY_NAME}
providerDisplayName={existingLlmProvider?.name ?? "Gemini"}
providerInternalName={"vertex_ai"}
providerEndpoint={VERTEXAI_PROVIDER_NAME}
existingLlmProvider={existingLlmProvider}
>
@@ -57,17 +100,26 @@ export function VertexAIForm({
existingLlmProvider,
wellKnownLLMProvider
);

const isAutoMode = existingLlmProvider?.is_auto_mode ?? true;
const autoModelDefault =
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL;

const defaultModel = shouldMarkAsDefault
? isAutoMode
? autoModelDefault
: defaultLlmModel?.model_name ?? VERTEXAI_DEFAULT_MODEL
: undefined;

const initialValues: VertexAIFormValues = {
...buildDefaultInitialValues(
existingLlmProvider,
modelConfigurations
modelConfigurations,
defaultModel
),
default_model_name:
existingLlmProvider?.default_model_name ??
wellKnownLLMProvider?.recommended_default_model?.name ??
VERTEXAI_DEFAULT_MODEL,
// Default to auto mode for new Vertex AI providers
is_auto_mode: existingLlmProvider?.is_auto_mode ?? true,
is_auto_mode: isAutoMode,
custom_config: {
vertex_credentials:
(existingLlmProvider?.custom_config
@@ -129,34 +181,64 @@ export function VertexAIForm({
{(formikProps) => {
return (
<Form className={LLM_FORM_CLASS_NAME}>
<LLMFormLayout.Body>
<InputSelectField name={VERTEXAI_REGION_NAME}>
<Text as="p">Google Cloud Region Name</Text>
<InputSelect.Trigger placeholder="Select region" />
<InputSelect.Content>
{VERTEXAI_REGION_OPTIONS.map((option) => (
<InputSelect.Item
key={option.value}
value={option.value}
>
{option.name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelectField>

<FormField
name="custom_config.vertex_credentials"
state={
formikProps.errors.custom_config?.vertex_credentials
? "error"
: "idle"
}
>
<FormField.Label>API Key</FormField.Label>
<FormField.Control>
<InputFile
setValue={(value) =>
formikProps.setFieldValue(
"custom_config.vertex_credentials",
value
)
}
error={
!!formikProps.errors.custom_config
?.vertex_credentials
}
onBlur={formikProps.handleBlur}
showClearButton={true}
disabled={formikProps.isSubmitting}
accept="application/json"
placeholder="Vertex AI API KEY (JSON)"
/>
</FormField.Control>
</FormField>
</LLMFormLayout.Body>

<DisplayNameField disabled={!!existingLlmProvider} />

<FileUploadFormField
name="custom_config.vertex_credentials"
label="Credentials File"
subtext="Upload your Google Cloud service account JSON credentials file."
/>

<TextFormField
name="custom_config.vertex_location"
label="Location"
placeholder={VERTEXAI_DEFAULT_LOCATION}
subtext="The Google Cloud region for your Vertex AI models (e.g., global, us-east1, us-central1, europe-west1). See [Google's documentation](https://docs.cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#google_model_endpoint_locations) to find the appropriate region for your model."
optional
/>

<Separator />
<Separator noPadding />

<DisplayModels
modelConfigurations={modelConfigurations}
formikProps={formikProps}
recommendedDefaultModel={
wellKnownLLMProvider?.recommended_default_model ?? null
}
shouldShowAutoUpdateToggle={true}
/>

<AdvancedOptions formikProps={formikProps} />
<ModelAccessOptions />

<FormActionButtons
isTesting={isTesting}

@@ -0,0 +1,86 @@
"use client";

import { Card } from "@/refresh-components/cards";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { DefaultModelSelectorProps } from "../../interfaces";
import { Section } from "@/layouts/general-layouts";
import { useMemo } from "react";
import * as InputLayouts from "@/layouts/input-layouts";
import { setDefaultLlmModel } from "@/lib/admin/llm/svc";

export function DefaultModelSelector({
existingLlmProviders,
defaultLlmModel,
onModelChange,
}: DefaultModelSelectorProps) {
// Flatten all models from all providers into a single list
const models = useMemo(() => {
return existingLlmProviders.flatMap((provider) =>
provider.model_configurations
.filter((model) => model.is_visible)
.map((model) => ({
model_display_name: model.display_name || model.name,
model_name: model.name,
provider_display_name: provider.name || provider.provider,
provider_id: provider.id,
}))
);
}, [existingLlmProviders]);
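// Illustrative shape of one flattened entry (field values are examples only):
//   { model_display_name: "GPT-4o", model_name: "gpt-4o",
//     provider_display_name: "OpenAI Prod", provider_id: 5 }
// Hidden models (is_visible === false) are dropped before flattening.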

const modelChangeHandler = (provider_id: number, model_name: string) => {
setDefaultLlmModel(provider_id, model_name);
onModelChange(provider_id, model_name);
};

// Create a composite value string for the select (provider_id:model_name)
// Fall back to the first model if no default is set
const currentValue = useMemo(() => {
if (defaultLlmModel) {
return `${defaultLlmModel.provider_id}:${defaultLlmModel.model_name}`;
}
const firstModel = models[0];
if (firstModel) {
return `${firstModel.provider_id}:${firstModel.model_name}`;
}
return undefined;
}, [defaultLlmModel, models]);

const handleValueChange = (value: string) => {
const separatorIndex = value.indexOf(":");
const providerId = parseInt(value.slice(0, separatorIndex), 10);
const modelName = value.slice(separatorIndex + 1);

modelChangeHandler(providerId, modelName);
};
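// Illustration: for value "5:anthropic.claude-3", separatorIndex is 1, so
// providerId parses to 5 and modelName to "anthropic.claude-3"; slicing at
// the first ":" (rather than split(":")) keeps model names that themselves
// contain colons intact.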

return (
<Card>
<Section
flexDirection="row"
justifyContent="between"
alignItems="center"
height="fit"
>
<InputLayouts.Horizontal
title="Default Model"
description="This model will be used by Onyx by default in your chats."
>
<InputSelect value={currentValue} onValueChange={handleValueChange}>
<InputSelect.Trigger placeholder="Select a model..." />
<InputSelect.Content>
{models.map((model) => (
<InputSelect.Item
key={`${model.provider_id}:${model.model_name}`}
value={`${model.provider_id}:${model.model_name}`}
description={model.provider_display_name}
>
{model.model_display_name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelect>
</InputLayouts.Horizontal>
</Section>
</Card>
);
}
@@ -1,11 +1,25 @@
import { ModelConfiguration, SimpleKnownModel } from "../../interfaces";
import { FormikProps } from "formik";
import { BaseLLMFormValues } from "../formUtils";
import { useState } from "react";

import Checkbox from "@/refresh-components/inputs/Checkbox";
import Switch from "@/refresh-components/inputs/Switch";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import IconButton from "@/refresh-components/buttons/IconButton";
import { Card } from "@/refresh-components/cards";
import Separator from "@/refresh-components/Separator";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { SvgChevronDown, SvgRefreshCw } from "@opal/icons";
import { cn } from "@/lib/utils";
import { FieldLabel } from "@/components/Field";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "@/refresh-components/Collapsible";
import * as GeneralLayouts from "@/layouts/general-layouts";
import { SvgEmpty } from "@opal/icons";

interface AutoModeToggleProps {
isAutoMode: boolean;
@@ -14,51 +28,173 @@ interface AutoModeToggleProps {

function AutoModeToggle({ isAutoMode, onToggle }: AutoModeToggleProps) {
return (
<div className="flex items-center justify-between">
<div>
<Text as="p" mainUiAction className="block">
Auto Update
<GeneralLayouts.Section
flexDirection="row"
justifyContent="between"
alignItems="center"
>
<GeneralLayouts.Section gap={0.125} alignItems="start" width="fit">
<GeneralLayouts.Section
flexDirection="row"
gap={0.375}
alignItems="center"
width="fit"
>
<Text as="span" mainUiAction>
Auto Update
</Text>
<Text as="span" secondaryBody text03>
(Recommended)
</Text>
</GeneralLayouts.Section>
<Text as="p" secondaryBody text03>
Update the available models when new models are released.
</Text>
<Text as="p" secondaryBody text03 className="block">
Automatically update the available models when new models are
released. Recommended for most teams.
</Text>
</div>
<button
type="button"
role="switch"
aria-checked={isAutoMode}
className={cn(
"relative inline-flex h-6 w-11 shrink-0 cursor-pointer rounded-full",
"border-2 border-transparent transition-colors duration-200 ease-in-out focus:outline-none",
isAutoMode ? "bg-action-link-05" : "bg-background-neutral-03"
)}
onClick={onToggle}
>
<span
className={cn(
"pointer-events-none inline-block h-5 w-5 transform rounded-full",
"bg-white shadow ring-0 transition duration-200 ease-in-out",
isAutoMode ? "translate-x-5" : "translate-x-0"
)}
/>
</button>
</div>
</GeneralLayouts.Section>
<Switch checked={isAutoMode} onCheckedChange={onToggle} />
</GeneralLayouts.Section>
);
}

function DisplayModelHeader({ alternativeText }: { alternativeText?: string }) {
interface DisplayModelHeaderProps {
alternativeText?: string;
onSelectAll?: () => void;
onRefresh?: () => void;
isAutoMode: boolean;
showSelectAll?: boolean;
}

function DisplayModelHeader({
alternativeText,
onSelectAll,
onRefresh,
isAutoMode,
showSelectAll = true,
}: DisplayModelHeaderProps) {
return (
<div>
<FieldLabel
label="Available Models"
subtext={
alternativeText ??
"Select which models to make available for this provider."
}
name="_available-models"
/>
</div>
<GeneralLayouts.Section
flexDirection="row"
justifyContent="between"
alignItems="center"
>
<GeneralLayouts.Section gap={0} alignItems="start" width="fit">
<Text as="p" mainContentBody>
Models
</Text>
<Text as="p" secondaryBody text03>
{alternativeText ??
"Select models to make available for this provider."}
</Text>
</GeneralLayouts.Section>
<GeneralLayouts.Section flexDirection="row" gap={0.25} width="fit">
{showSelectAll && (
<Button main tertiary onClick={onSelectAll} disabled={isAutoMode}>
Select All
</Button>
)}
<IconButton
icon={SvgRefreshCw}
main
internal
onClick={onRefresh}
tooltip="Refresh models"
/>
</GeneralLayouts.Section>
</GeneralLayouts.Section>
);
}

interface ModelRowProps {
modelName: string;
modelDisplayName?: string;
isSelected: boolean;
isDefault: boolean;
onCheckChange: (checked: boolean) => void;
onSetDefault?: () => void;
}

function ModelRow({
modelName,
modelDisplayName,
isSelected,
isDefault,
onCheckChange,
onSetDefault,
}: ModelRowProps) {
return (
<Card
variant="borderless"
flexDirection="row"
justifyContent="between"
alignItems="center"
className={cn("cursor-pointer group hover:bg-background-tint-01")}
padding={0.25}
onClick={() => onCheckChange(!isSelected)}
>
<GeneralLayouts.Section
flexDirection="row"
gap={0.75}
alignItems="center"
width="fit"
>
<Checkbox
checked={isSelected}
onCheckedChange={(checked) => onCheckChange(checked)}
onClick={(e) => e.stopPropagation()}
/>
<Text
as="span"
className={cn(
"select-none",
isSelected ? "text-action-link-04" : "text-text-03"
)}
>
{modelDisplayName ?? modelName}
</Text>
</GeneralLayouts.Section>
{isDefault ? (
<Text as="span" secondaryBody className="text-action-link-05">
Default Model
</Text>
) : (
onSetDefault && (
<Button
main
tertiary
className="opacity-0 group-hover:opacity-100 transition-opacity"
onClick={(e) => {
e.stopPropagation();
onSetDefault();
}}
>
<Text text03>Set as Default</Text>
</Button>
)
)}
</Card>
);
}

interface MoreModelsButtonProps {
isOpen: boolean;
}

function MoreModelsButton({ isOpen }: MoreModelsButtonProps) {
return (
<Button
main
internal
transient
leftIcon={SvgChevronDown}
className={cn(
"[&_svg]:transition-transform [&_svg]:duration-200",
isOpen && "[&_svg]:rotate-180"
)}
>
<Text as="span" text02>
More Models
</Text>
</Button>
);
}

@@ -67,54 +203,35 @@ export function DisplayModels<T extends BaseLLMFormValues>({
modelConfigurations,
noModelConfigurationsMessage,
isLoading,
recommendedDefaultModel,
shouldShowAutoUpdateToggle,
}: {
formikProps: FormikProps<T>;
modelConfigurations: ModelConfiguration[];
noModelConfigurationsMessage?: string;
isLoading?: boolean;
recommendedDefaultModel: SimpleKnownModel | null;
shouldShowAutoUpdateToggle: boolean;
}) {
const [moreModelsOpen, setMoreModelsOpen] = useState(false);
const isAutoMode = formikProps.values.is_auto_mode;

if (isLoading) {
return (
<div>
<DisplayModelHeader />
<div className="mt-2 flex items-center p-3 border border-border-01 rounded-lg bg-background-neutral-00">
<div className="h-5 w-5 animate-spin rounded-full border-2 border-border-03 border-t-action-link-05" />
</div>
</div>
);
}
const defaultModelName = formikProps.values.default_model_name;

const handleCheckboxChange = (modelName: string, checked: boolean) => {
// Read current values inside the handler to avoid stale closure issues
if (!checked && modelName === defaultModelName) {
return;
}

const currentSelected = formikProps.values.selected_model_names ?? [];
const currentDefault = formikProps.values.default_model_name;

if (checked) {
const newSelected = [...currentSelected, modelName];
formikProps.setFieldValue("selected_model_names", newSelected);
// If this is the first model, set it as default
if (currentSelected.length === 0) {
formikProps.setFieldValue("default_model_name", modelName);
}
} else {
const newSelected = currentSelected.filter((name) => name !== modelName);
formikProps.setFieldValue("selected_model_names", newSelected);
// If removing the default, set the first remaining model as default
if (currentDefault === modelName && newSelected.length > 0) {
formikProps.setFieldValue("default_model_name", newSelected[0]);
} else if (newSelected.length === 0) {
formikProps.setFieldValue("default_model_name", null);
}
}
};
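// Illustration: with selected ["a", "b"] and default "a", unchecking "a" is
// ignored (the default cannot be deselected directly), while unchecking "b"
// leaves the default alone; checking the first model of an empty selection
// promotes it to default, and emptying the selection resets
// default_model_name to null.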

const handleSetDefault = (modelName: string) => {
const onSetDefault = (modelName: string) => {
formikProps.setFieldValue("default_model_name", modelName);
};

@@ -124,19 +241,31 @@ export function DisplayModels<T extends BaseLLMFormValues>({
"selected_model_names",
modelConfigurations.filter((m) => m.is_visible).map((m) => m.name)
);
formikProps.setFieldValue(
"default_model_name",
recommendedDefaultModel?.name ?? ""
);
};

const handleSelectAll = () => {
const allModelNames = modelConfigurations.map((m) => m.name);
formikProps.setFieldValue("selected_model_names", allModelNames);
};

const handleRefresh = () => {
// Trigger a refresh of models - this would need to be wired up to your fetch logic
// For now, this is a placeholder
};

const selectedModels = formikProps.values.selected_model_names ?? [];
const defaultModel = formikProps.values.default_model_name;

// Sort models: default first, then selected, then unselected
const primaryModels = modelConfigurations.filter((m) =>
selectedModels.includes(m.name)
);
const moreModels = modelConfigurations.filter(
(m) => !selectedModels.includes(m.name)
);

const sortedModelConfigurations = [...modelConfigurations].sort((a, b) => {
const aIsDefault = a.name === defaultModel;
const bIsDefault = b.name === defaultModel;
const aIsDefault = a.name === defaultModelName;
const bIsDefault = b.name === defaultModelName;
const aIsSelected = selectedModels.includes(a.name);
const bIsSelected = selectedModels.includes(b.name);

@@ -147,166 +276,136 @@ export function DisplayModels<T extends BaseLLMFormValues>({
return 0;
});

if (modelConfigurations.length === 0) {
return (
<div>
<DisplayModelHeader
alternativeText={noModelConfigurationsMessage ?? "No models found"}
/>
</div>
);
}

// Sort auto mode models: default model first
// For auto mode display
const visibleModels = modelConfigurations.filter((m) => m.is_visible);
const sortedAutoModels = [...visibleModels].sort((a, b) => {
const aIsDefault = a.name === defaultModel;
const bIsDefault = b.name === defaultModel;
if (aIsDefault && !bIsDefault) return -1;
if (!aIsDefault && bIsDefault) return 1;
return 0;
});

const primaryAutoModels = visibleModels.filter((m) =>
selectedModels.includes(m.name)
);
const moreAutoModels = visibleModels.filter(
(m) => !selectedModels.includes(m.name)
);

return (
<div className="flex flex-col gap-3">
<DisplayModelHeader />
<div className="border border-border-01 rounded-lg p-3">
{shouldShowAutoUpdateToggle && (
<AutoModeToggle
isAutoMode={isAutoMode}
onToggle={handleToggleAutoMode}
/>
)}

{/* Model list section */}
<div
className={cn(
"flex flex-col gap-1",
shouldShowAutoUpdateToggle && "mt-3 pt-3 border-t border-border-01"
)}
>
<Card variant="borderless">
<DisplayModelHeader
onSelectAll={handleSelectAll}
onRefresh={handleRefresh}
showSelectAll={modelConfigurations.length > 0}
isAutoMode={isAutoMode}
/>
{modelConfigurations.length > 0 ? (
<Card variant="borderless" padding={0} gap={0}>
{isAutoMode && shouldShowAutoUpdateToggle ? (
// Auto mode: read-only display
<div className="flex flex-col gap-2">
{sortedAutoModels.map((model) => {
const isDefault = model.name === defaultModel;
<>
{primaryAutoModels.map((model) => {
const isDefault = model.name === defaultModelName;
return (
<div
<ModelRow
key={model.name}
className={cn(
"flex items-center justify-between gap-3 rounded-lg border p-1",
"bg-background-neutral-00",
isDefault ? "border-action-link-05" : "border-border-01"
)}
>
<div className="flex flex-1 items-center gap-2 px-2 py-1">
<div
className={cn(
"size-2 shrink-0 rounded-full",
isDefault
? "bg-action-link-05"
: "bg-background-neutral-03"
)}
/>
<div className="flex flex-col gap-0.5">
<Text mainUiAction text05>
{model.display_name || model.name}
</Text>
{model.display_name && (
<Text secondaryBody text03>
{model.name}
</Text>
)}
</div>
</div>
{isDefault && (
<div className="flex items-center justify-end pr-2">
<Text
secondaryBody
className="text-action-text-link-05"
>
Default
</Text>
</div>
)}
</div>
modelName={model.name}
modelDisplayName={model.display_name}
isSelected={true}
isDefault={isDefault}
onCheckChange={() => {}}
/>
);
})}
</div>
</>
) : (
// Manual mode: checkbox selection
<div
className={cn(
"flex flex-col gap-1",
"max-h-48 4xl:max-h-64",
"overflow-y-auto"
)}
>
{sortedModelConfigurations.map((modelConfiguration) => {
<>
{primaryModels.map((modelConfiguration) => {
const isSelected = selectedModels.includes(
modelConfiguration.name
);
const isDefault = defaultModel === modelConfiguration.name;
const isDefault = defaultModelName === modelConfiguration.name;

return (
<div
<ModelRow
key={modelConfiguration.name}
className="flex items-center justify-between py-1.5 px-2 rounded hover:bg-background-neutral-subtle"
>
<div
className="flex items-center gap-2 cursor-pointer"
onClick={() =>
handleCheckboxChange(
modelConfiguration.name,
!isSelected
)
}
>
<div
className="flex items-center"
onClick={(e) => e.stopPropagation()}
>
<Checkbox
checked={isSelected}
onCheckedChange={(checked) =>
handleCheckboxChange(
modelConfiguration.name,
checked
)
}
/>
</div>
<Text
as="p"
secondaryBody
className="select-none leading-none"
>
{modelConfiguration.name}
</Text>
</div>
<button
type="button"
disabled={!isSelected}
onClick={() => handleSetDefault(modelConfiguration.name)}
className={`text-xs px-2 py-0.5 rounded transition-all duration-200 ease-in-out ${
isSelected
? "opacity-100 translate-x-0"
: "opacity-0 translate-x-2 pointer-events-none"
} ${
isDefault
? "bg-action-link-05 text-text-inverse font-medium scale-100"
: "bg-background-neutral-02 text-text-03 hover:bg-background-neutral-03 scale-95 hover:scale-100"
}`}
>
{isDefault ? "Default" : "Set as default"}
</button>
</div>
modelName={modelConfiguration.name}
modelDisplayName={modelConfiguration.display_name}
isSelected={isSelected}
isDefault={isDefault}
onCheckChange={(checked) =>
handleCheckboxChange(modelConfiguration.name, checked)
}
onSetDefault={
onSetDefault
? () => onSetDefault(modelConfiguration.name)
: undefined
}
/>
);
})}
</div>
</>
)}
</div>
</div>
</div>

{moreModels.length > 0 && (
<Collapsible open={moreModelsOpen} onOpenChange={setMoreModelsOpen}>
<CollapsibleTrigger asChild>
<Card variant="borderless" padding={0}>
<MoreModelsButton isOpen={moreModelsOpen} />
</Card>
</CollapsibleTrigger>
<CollapsibleContent>
<GeneralLayouts.Section gap={0.25} alignItems="start">
{moreModels.map((modelConfiguration) => {
const isSelected = selectedModels.includes(
modelConfiguration.name
);
const isDefault =
defaultModelName === modelConfiguration.name;

return (
<ModelRow
key={modelConfiguration.name}
modelName={modelConfiguration.name}
modelDisplayName={modelConfiguration.display_name}
isSelected={isSelected}
isDefault={isDefault}
onCheckChange={(checked) =>
handleCheckboxChange(modelConfiguration.name, checked)
}
onSetDefault={() =>
onSetDefault(modelConfiguration.name)
}
/>
);
})}
</GeneralLayouts.Section>
</CollapsibleContent>
</Collapsible>
)}

{/* Auto update toggle */}
{shouldShowAutoUpdateToggle && (
<Card variant="borderless" padding={0.75} gap={0.75}>
<Separator noPadding />
<AutoModeToggle
isAutoMode={isAutoMode}
onToggle={handleToggleAutoMode}
/>
</Card>
)}
</Card>
) : (
<Card variant="tertiary">
<GeneralLayouts.Section
gap={0.5}
flexDirection="row"
alignItems="center"
justifyContent="start"
>
<SvgEmpty className="w-4 h-4 line-item-icon-muted" />
<Text text03 secondaryBody>
{noModelConfigurationsMessage ?? "No models found"}
</Text>
</GeneralLayouts.Section>
</Card>
)}
</Card>
);
}

@@ -1,4 +1,7 @@
import { TextFormField } from "@/components/Field";
import Text from "@/refresh-components/texts/Text";
import * as GeneralLayouts from "@/layouts/general-layouts";
import InputWrapper from "./InputWrapper";

interface DisplayNameFieldProps {
disabled?: boolean;
@@ -6,12 +9,17 @@ interface DisplayNameFieldProps {

export function DisplayNameField({ disabled = false }: DisplayNameFieldProps) {
return (
<TextFormField
name="name"
<InputWrapper
label="Display Name"
subtext="A name which you can use to identify this provider when selecting it in the UI."
placeholder="Display Name"
disabled={disabled}
/>
optional
description="Use to identify this provider in the app"
>
<TextFormField
name="name"
label=""
placeholder="Display Name"
disabled={disabled}
/>
</InputWrapper>
);
}

@@ -2,13 +2,14 @@ import { LoadingAnimation } from "@/components/Loading";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { SvgTrash } from "@opal/icons";
import { LLMProviderView } from "../../interfaces";
import { DefaultModel, LLMProviderView } from "../../interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";

interface FormActionButtonsProps {
isTesting: boolean;
testError: string;
existingLlmProvider?: LLMProviderView;
defaultModel?: DefaultModel;
mutate: (key: string) => void;
onClose: () => void;
isFormValid: boolean;
@@ -18,6 +19,7 @@ export function FormActionButtons({
isTesting,
testError,
existingLlmProvider,
defaultModel,
mutate,
onClose,
isFormValid,
@@ -39,16 +41,22 @@ export function FormActionButtons({
}

// If the deleted provider was the default, set the first remaining provider as default
if (existingLlmProvider.is_default_provider) {
if (defaultModel?.provider_id === existingLlmProvider.id) {
const remainingProvidersResponse = await fetch(LLM_PROVIDERS_ADMIN_URL);
if (remainingProvidersResponse.ok) {
const remainingProviders = await remainingProvidersResponse.json();
const remainingProvidersJson = await remainingProvidersResponse.json();
const remainingProviders = remainingProvidersJson.providers;

if (remainingProviders.length > 0) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
`${LLM_PROVIDERS_ADMIN_URL}/default`,
{
method: "POST",
body: JSON.stringify({
provider_id: remainingProviders[0].id,
model_name:
remainingProviders[0].model_configurations[0]?.name ?? "",
}),
}
);
if (!setDefaultResponse.ok) {

@@ -0,0 +1,184 @@
"use client";

import React from "react";
import { Section } from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import Modal from "@/refresh-components/Modal";
import { LoadingAnimation } from "@/components/Loading";
import type { IconProps } from "@opal/types";

/**
* LLMFormLayout - A compound component for LLM provider configuration forms
*
* Provides a complete modal form structure with:
* - Modal: The modal wrapper with header (icon, title, close button)
* - Body: Form content area
* - Footer: Cancel and Connect/Update buttons
*
* @example
* ```tsx
* <LLMFormLayout.Modal
*   icon={OpenAIIcon}
*   displayName="OpenAI"
*   providerName="openai"
*   onClose={handleClose}
*   isEditing={false}
* >
*   <Formik ...>
*     {(formikProps) => (
*       <Form>
*         <LLMFormLayout.Body>
*           <InputField name="api_key" label="API Key" />
*         </LLMFormLayout.Body>
*         <LLMFormLayout.Footer
*           onCancel={handleClose}
*           isSubmitting={isSubmitting}
*         />
*       </Form>
*     )}
*   </Formik>
* </LLMFormLayout.Modal>
* ```
*/

// ============================================================================
// Modal (Root)
// ============================================================================

interface LLMFormModalProps {
children: React.ReactNode;
/** Icon component for the provider */
icon: React.FunctionComponent<IconProps>;
/** Display name shown in the modal title */
displayName: string;
/** Whether editing an existing provider (changes title) */
isEditing?: boolean;
/** Custom name of the existing provider (shown in title when editing) */
existingName?: string;
/** Handler for closing the modal */
onClose: () => void;
}

function LLMFormModal({
children,
icon: Icon,
displayName,
isEditing = false,
existingName,
onClose,
}: LLMFormModalProps) {
const title = isEditing
? `Configure ${existingName ? `"${existingName}"` : displayName}`
: `Set up ${displayName}`;
|
||||
const description = isEditing
|
||||
? ""
|
||||
: `Connect to ${displayName} to set up your ${displayName} models.`;
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content width="md-sm">
|
||||
<Modal.Header
|
||||
icon={Icon}
|
||||
title={title}
|
||||
description={description}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body alignItems="stretch">{children}</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Body
|
||||
// ============================================================================
|
||||
|
||||
interface LLMFormBodyProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
function LLMFormBody({ children }: LLMFormBodyProps) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Footer
|
||||
// ============================================================================
|
||||
|
||||
interface LLMFormFooterProps {
|
||||
/** Handler for cancel button */
|
||||
onCancel: () => void;
|
||||
/** Label for the cancel button. Default: "Cancel" */
|
||||
cancelLabel?: string;
|
||||
/** Label for the submit button. Default: "Connect" */
|
||||
submitLabel?: string;
|
||||
/** Text to show while submitting. Default: "Testing" */
|
||||
submittingLabel?: string;
|
||||
/** Whether the form is currently submitting */
|
||||
isSubmitting?: boolean;
|
||||
/** Whether the submit button should be disabled */
|
||||
isSubmitDisabled?: boolean;
|
||||
/** Optional left-side content (e.g., delete button) */
|
||||
leftChildren?: React.ReactNode;
|
||||
/** Error message to display */
|
||||
error?: string;
|
||||
}
|
||||
|
||||
function LLMFormFooter({
|
||||
onCancel,
|
||||
cancelLabel = "Cancel",
|
||||
submitLabel = "Connect",
|
||||
submittingLabel = "Testing",
|
||||
isSubmitting = false,
|
||||
isSubmitDisabled = false,
|
||||
leftChildren,
|
||||
error,
|
||||
}: LLMFormFooterProps) {
|
||||
return (
|
||||
<Section alignItems="stretch" gap={0.5}>
|
||||
{error && (
|
||||
<Text as="p" className="text-error">
|
||||
{error}
|
||||
</Text>
|
||||
)}
|
||||
<Section flexDirection="row" justifyContent="between" gap={0.5}>
|
||||
<Section
|
||||
width="fit"
|
||||
flexDirection="row"
|
||||
justifyContent="start"
|
||||
gap={0.5}
|
||||
>
|
||||
{leftChildren}
|
||||
</Section>
|
||||
<Section width="fit" flexDirection="row" justifyContent="end" gap={0.5}>
|
||||
<Button secondary onClick={onCancel} disabled={isSubmitting}>
|
||||
{cancelLabel}
|
||||
</Button>
|
||||
<Button type="submit" disabled={isSubmitting || isSubmitDisabled}>
|
||||
{isSubmitting ? (
|
||||
<Text as="span" inverted>
|
||||
<LoadingAnimation text={submittingLabel} />
|
||||
</Text>
|
||||
) : (
|
||||
submitLabel
|
||||
)}
|
||||
</Button>
|
||||
</Section>
|
||||
</Section>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Export
|
||||
// ============================================================================
|
||||
|
||||
const LLMFormLayout = {
|
||||
Modal: LLMFormModal,
|
||||
Body: LLMFormBody,
|
||||
Footer: LLMFormFooter,
|
||||
};
|
||||
|
||||
export default LLMFormLayout;
|
||||
@@ -8,13 +8,13 @@ import {
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "../../interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { SvgSettings } from "@opal/icons";
import { Badge } from "@/components/ui/badge";
import { cn } from "@/lib/utils";
import { LLM_PROVIDERS_ADMIN_URL } from "../../constants";
import { SvgArrowExchange, SvgSettings } from "@opal/icons";
import { Card } from "@/refresh-components/cards";
import * as GeneralLayouts from "@/layouts/general-layouts";
import { ProviderIcon } from "../../ProviderIcon";
import IconButton from "@/refresh-components/buttons/IconButton";
import LineItem from "@/refresh-components/buttons/LineItem";
import LLMFormLayout from "./FormLayout";

export interface ProviderFormContext {
onClose: () => void;
@@ -31,21 +31,19 @@ export interface ProviderFormContext {
interface ProviderFormEntrypointWrapperProps {
children: (context: ProviderFormContext) => ReactNode;
providerName: string;
providerDisplayName?: string;
providerInternalName?: string;
providerEndpoint?: string;
existingLlmProvider?: LLMProviderView;
/** When true, renders a simple button instead of a card-based UI */
buttonMode?: boolean;
/** Custom button text for buttonMode (defaults to "Add {providerName}") */
buttonText?: string;
}

export function ProviderFormEntrypointWrapper({
children,
providerName,
providerDisplayName,
providerInternalName,
providerEndpoint,
existingLlmProvider,
buttonMode,
buttonText,
}: ProviderFormEntrypointWrapperProps) {
const [formIsVisible, setFormIsVisible] = useState(false);

@@ -67,31 +65,6 @@ export function ProviderFormEntrypointWrapper({

const onClose = () => setFormIsVisible(false);

async function handleSetAsDefault(): Promise<void> {
if (!existingLlmProvider) return;

const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
{
method: "POST",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
setPopup({
type: "error",
message: `Failed to set provider as default: ${errorMsg}`,
});
return;
}

await mutate(LLM_PROVIDERS_ADMIN_URL);
setPopup({
type: "success",
message: "Provider set as default successfully!",
});
}

const context: ProviderFormContext = {
onClose,
mutate,
@@ -104,113 +77,51 @@ export function ProviderFormEntrypointWrapper({
wellKnownLLMProvider,
};

// Button mode: simple button that opens a modal
if (buttonMode && !existingLlmProvider) {
return (
<>
{popup}
<Button action onClick={() => setFormIsVisible(true)}>
{buttonText ?? `Add ${providerName}`}
</Button>
const displayName = providerDisplayName ?? providerName;
const internalName = providerInternalName ?? providerName;

{formIsVisible && (
<Modal open onOpenChange={onClose}>
<Modal.Content>
<Modal.Header
icon={SvgSettings}
title={`Setup ${providerName}`}
onClose={onClose}
/>
<Modal.Body>{children(context)}</Modal.Body>
</Modal.Content>
</Modal>
)}
</>
);
}

// Card mode: card-based UI with modal
return (
<div>
<>
{popup}
<div className="border p-3 bg-background-neutral-01 rounded-16 w-96 flex shadow-md">
<Card padding={0}>
{existingLlmProvider ? (
<>
<div className="my-auto">
<Text
as="p"
headingH3
text04
className="text-ellipsis overflow-hidden max-w-32"
>
{existingLlmProvider.name}
</Text>
<Text as="p" secondaryBody text03 className="italic">
({providerName})
</Text>
{!existingLlmProvider.is_default_provider && (
<Text
as="p"
className={cn("text-action-link-05", "cursor-pointer")}
onClick={handleSetAsDefault}
>
Set as default
</Text>
)}
</div>

{existingLlmProvider && (
<div className="my-auto ml-3">
{existingLlmProvider.is_default_provider ? (
<Badge variant="agent">Default</Badge>
) : (
<Badge variant="success">Enabled</Badge>
)}
</div>
)}

<div className="ml-auto my-auto">
<Button
action={!existingLlmProvider}
secondary={!!existingLlmProvider}
<GeneralLayouts.CardItemLayout
icon={() => <ProviderIcon provider={internalName} size={24} />}
title={displayName}
description={providerName}
rightChildren={
<IconButton
icon={SvgSettings}
internal
onClick={() => setFormIsVisible(true)}
>
Edit
</Button>
</div>
</>
/>
}
/>
) : (
<>
<div className="my-auto">
<Text as="p" headingH3>
{providerName}
</Text>
</div>
<div className="ml-auto my-auto">
<Button action onClick={() => setFormIsVisible(true)}>
Set up
</Button>
</div>
</>
<GeneralLayouts.CardItemLayout
icon={() => <ProviderIcon provider={internalName} size={24} />}
title={displayName}
description={providerName}
rightChildren={
<LineItem
children="Connect"
icon={SvgArrowExchange}
onClick={() => setFormIsVisible(true)}
/>
}
/>
)}
</div>
</Card>

{formIsVisible && (
<Modal open onOpenChange={onClose}>
<Modal.Content>
<Modal.Header
icon={SvgSettings}
title={`${existingLlmProvider ? "Configure" : "Setup"} ${
existingLlmProvider?.name
? `"${existingLlmProvider.name}"`
: providerName
}`}
onClose={onClose}
/>
<Modal.Body>{children(context)}</Modal.Body>
</Modal.Content>
</Modal>
<LLMFormLayout.Modal
icon={() => <ProviderIcon provider={internalName} size={24} />}
displayName={displayName}
onClose={onClose}
>
{children(context)}
</LLMFormLayout.Modal>
)}
</div>
</>
);
}

@@ -0,0 +1,75 @@
import * as GeneralLayouts from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";

interface DescriptionLink {
text: string;
href: string;
}

interface InputWrapperProps {
label: string;
description?: string;
descriptionLink?: DescriptionLink;
children: React.ReactNode;
optional?: boolean;
}

export default function InputWrapper({
label,
optional,
description,
descriptionLink,
children,
}: InputWrapperProps) {
const renderDescription = () => {
if (!description) return null;

if (descriptionLink && description.includes("{link}")) {
const [before, after] = description.split("{link}");
return (
<Text as="p" secondaryBody text03>
{before}
<a
href={descriptionLink.href}
target="_blank"
rel="noopener noreferrer"
className="underline"
>
{descriptionLink.text}
</a>
{after}
</Text>
);
}

return (
<Text as="p" secondaryBody text03>
{description}
</Text>
);
};

return (
<GeneralLayouts.Section
flexDirection="column"
alignItems="start"
gap={0.25}
>
<GeneralLayouts.Section
flexDirection="row"
gap={0.25}
alignItems="center"
justifyContent="start"
>
<Text as="p">{label}</Text>
{optional && (
<Text as="p" text03>
(Optional)
</Text>
)}
</GeneralLayouts.Section>
{children}
{renderDescription()}
</GeneralLayouts.Section>
);
}
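A minimal usage sketch for the wrapper above; the field, label, and URL here are illustrative, not part of this change:

import InputWrapper from "./InputWrapper";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";

// "{link}" in the description is replaced by an anchor built from descriptionLink.
function ApiKeyField() {
  return (
    <InputWrapper
      label="API Key"
      optional
      description="Create a key in the {link} and paste it here."
      descriptionLink={{ text: "provider console", href: "https://example.com/keys" }}
    >
      <InputTypeIn variant="primary" placeholder="sk-..." />
    </InputWrapper>
  );
}
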
@@ -0,0 +1,64 @@
"use client";

import * as InputLayouts from "@/layouts/input-layouts";
import InputSelectField from "@/refresh-components/form/InputSelectField";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { useState } from "react";
import { SvgOrganization } from "@opal/icons";
import * as GeneralLayouts from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
import { Card } from "@/refresh-components/cards";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";

enum ModelAccessOption {
ALL = "all", // All users and agents
NAMED = "named", // Only named users and agents
}

export function ModelAccessOptions() {
const [modelAccessOption, setModelAccessOption] = useState<ModelAccessOption>(
ModelAccessOption.ALL
);

return (
<>
<InputLayouts.Horizontal
title="Model Access"
description="Who can access this provider."
>
<InputSelectField
name="model_access_options"
defaultValue={modelAccessOption}
onValueChange={(value) =>
setModelAccessOption(value as ModelAccessOption)
}
>
<InputSelect.Trigger />
<InputSelect.Content>
<InputSelect.Item value={ModelAccessOption.ALL}>
All users and agents
</InputSelect.Item>
<InputSelect.Item value={ModelAccessOption.NAMED}>
Named users and agents
</InputSelect.Item>
</InputSelect.Content>
</InputSelectField>
</InputLayouts.Horizontal>

{modelAccessOption === ModelAccessOption.NAMED && (
<NamedModelAccessOptions />
)}
</>
);
}

function NamedModelAccessOptions() {
return (
<Card>
<InputTypeIn
variant="primary"
placeholder="Add users, groups, accounts, and agents"
/>
</Card>
);
}
@@ -3,30 +3,26 @@ import {
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "../interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "../constants";
import { LLM_PROVIDERS_ADMIN_URL, LLM_ADMIN_URL } from "../constants";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
import { setDefaultLlmModel } from "@/lib/admin/llm/svc";

// Common class names for the Form component across all LLM provider forms
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";

export const buildDefaultInitialValues = (
existingLlmProvider?: LLMProviderView,
modelConfigurations?: ModelConfiguration[]
modelConfigurations?: ModelConfiguration[],
defaultModelName?: string
) => {
const defaultModelName =
existingLlmProvider?.default_model_name ??
modelConfigurations?.[0]?.name ??
"";

// Auto mode must be explicitly enabled by the user
// Default to false for new providers, preserve existing value when editing
const isAutoMode = existingLlmProvider?.is_auto_mode ?? false;

return {
name: existingLlmProvider?.name || "",
default_model_name: defaultModelName,
is_public: existingLlmProvider?.is_public ?? true,
is_auto_mode: isAutoMode,
groups: existingLlmProvider?.groups ?? [],
@@ -38,18 +34,19 @@ export const buildDefaultInitialValues = (
: modelConfigurations
?.filter((modelConfiguration) => modelConfiguration.is_visible)
.map((modelConfiguration) => modelConfiguration.name) ?? [],
default_model_name: defaultModelName,
};
};

export const buildDefaultValidationSchema = () => {
return Yup.object({
name: Yup.string().required("Display Name is required"),
default_model_name: Yup.string().required("Model name is required"),
is_public: Yup.boolean().required(),
is_auto_mode: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
personas: Yup.array().of(Yup.number()),
selected_model_names: Yup.array().of(Yup.string()),
default_model_name: Yup.string().optional(),
});
};

@@ -81,13 +78,13 @@ export interface BaseLLMFormValues {
name: string;
api_key?: string;
api_base?: string;
default_model_name?: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
selected_model_names: string[];
custom_config?: Record<string, string>;
default_model_name?: string;
}

export interface SubmitLLMProviderParams<
@@ -110,8 +107,7 @@ export interface SubmitLLMProviderParams<

export const filterModelConfigurations = (
currentModelConfigurations: ModelConfiguration[],
visibleModels: string[],
defaultModelName?: string
visibleModels: string[]
): ModelConfiguration[] => {
return currentModelConfigurations
.map(
@@ -123,11 +119,15 @@ export const filterModelConfigurations = (
display_name: modelConfiguration.display_name,
})
)
.filter(
(modelConfiguration) =>
modelConfiguration.name === defaultModelName ||
modelConfiguration.is_visible
);
.filter((modelConfiguration) => modelConfiguration.is_visible);
};

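The filter above no longer special-cases the default model: a configuration survives only if it is visible. A sketch of the behavior change, assuming (as the elided map suggests) that a model is marked visible when it appears in visibleModels:

const configs = [
  { name: "gpt-4", is_visible: true, display_name: "GPT-4" },
  { name: "gpt-4o", is_visible: true, display_name: "GPT-4o" },
];

// Before: filterModelConfigurations(configs, ["gpt-4o"], "gpt-4")
//   also kept "gpt-4", because it was the default model.
// After: filterModelConfigurations(configs, ["gpt-4o"])
//   returns only "gpt-4o"; the default model is now tracked separately.
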
const getFirstVisibleModelConfiguration = (
modelConfigurations: ModelConfiguration[]
): ModelConfiguration | undefined => {
return modelConfigurations.find(
(modelConfiguration) => modelConfiguration.is_visible
);
};

// Helper to get model configurations for auto mode
@@ -169,27 +169,14 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
// In auto mode, use recommended models from descriptor
// In manual mode, use user's selection
let filteredModelConfigurations: ModelConfiguration[];
let finalDefaultModelName = rest.default_model_name;

if (values.is_auto_mode) {
filteredModelConfigurations =
getAutoModeModelConfigurations(modelConfigurations);

// In auto mode, use the first recommended model as default if current default isn't in the list
const visibleModelNames = new Set(
filteredModelConfigurations.map((m) => m.name)
);
if (
finalDefaultModelName &&
!visibleModelNames.has(finalDefaultModelName)
) {
finalDefaultModelName = filteredModelConfigurations[0]?.name ?? "";
}
} else {
filteredModelConfigurations = filterModelConfigurations(
modelConfigurations,
visibleModels,
rest.default_model_name as string | undefined
visibleModels
);
}

@@ -200,7 +187,6 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({

const finalValues = {
...rest,
default_model_name: finalDefaultModelName,
api_key,
api_key_changed: api_key !== (initialValues.api_key as string | undefined),
custom_config_changed: customConfigChanged,
@@ -211,6 +197,13 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
if (!isEqual(finalValues, initialValues)) {
setIsTesting(true);

const testModel =
finalValues.default_model_name ??
(filteredModelConfigurations.length > 0
? getFirstVisibleModelConfiguration(filteredModelConfigurations)?.name
: modelConfigurations[0]?.name) ??
"";

const response = await fetch("/api/admin/llm/test", {
method: "POST",
headers: {
@@ -218,6 +211,7 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
},
body: JSON.stringify({
provider: providerName,
model: testModel,
...finalValues,
}),
});

@@ -263,13 +257,11 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
return;
}

if (shouldMarkAsDefault) {
if (shouldMarkAsDefault && finalValues.default_model_name) {
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{
method: "POST",
}
const setDefaultResponse = await setDefaultLlmModel(
newLlmProvider.id,
finalValues.default_model_name
);
if (!setDefaultResponse.ok) {
const errorMsg = (await setDefaultResponse.json()).detail;

@@ -47,20 +47,16 @@ export interface LLMProvider {
api_base: string | null;
api_version: string | null;
custom_config: { [key: string]: string } | null;
default_model_name: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
deployment_name: string | null;
default_vision_model: string | null;
is_default_vision_provider: boolean | null;
model_configurations: ModelConfiguration[];
}

export interface LLMProviderView extends LLMProvider {
id: number;
is_default_provider: boolean | null;
}

export interface VisionProvider extends LLMProviderView {
@@ -68,19 +64,26 @@ export interface VisionProvider extends LLMProviderView {
}

export interface LLMProviderDescriptor {
id: number;
name: string;
provider: string;
provider_display_name?: string;
default_model_name: string;
is_default_provider: boolean | null;
is_default_vision_provider?: boolean | null;
default_vision_model?: string | null;
is_public?: boolean;
groups?: number[];
personas?: number[];
model_configurations: ModelConfiguration[];
}

export interface DefaultModel {
provider_id: number;
model_name: string;
}

export interface LLMProviderResponse<
T extends LLMProviderView | LLMProviderDescriptor,
> {
providers: T[];
default_text: DefaultModel | null;
default_vision: DefaultModel | null;
}

export interface OllamaModelResponse {
name: string;
display_name: string;
@@ -104,9 +107,16 @@ export interface BedrockModelResponse {

export interface LLMProviderFormProps {
existingLlmProvider?: LLMProviderView;
defaultLlmModel?: DefaultModel;
shouldMarkAsDefault?: boolean;
}

export interface DefaultModelSelectorProps {
existingLlmProviders: LLMProviderView[];
defaultLlmModel: DefaultModel | null;
onModelChange: (provider_id: number, model_name: string) => void;
}

// Param types for model fetching functions - use snake_case to match API structure
export interface BedrockFetchParams {
aws_region_name: string;

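For orientation, a payload matching the new LLMProviderResponse shape might look roughly like this; the values are illustrative and the ModelConfiguration entries are trimmed to the fields that appear in this diff:

const exampleResponse = {
  providers: [
    {
      id: 5,
      name: "OpenAI",
      provider: "openai",
      default_model_name: "gpt-4",
      is_default_provider: true,
      model_configurations: [
        { name: "gpt-4", is_visible: true, display_name: "GPT-4" },
      ],
    },
  ],
  default_text: { provider_id: 5, model_name: "gpt-4" },
  default_vision: null,
};
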
@@ -3,12 +3,15 @@
import { AdminPageTitle } from "@/components/admin/Title";
import { LLMConfiguration } from "./LLMConfiguration";
import { SvgCpu } from "@opal/icons";
import * as SettingsLayouts from "@/layouts/settings-layouts";

export default function Page() {
return (
<>
<AdminPageTitle title="LLM Setup" icon={SvgCpu} />

<LLMConfiguration />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={SvgCpu} title="LLM Models" separator />
<SettingsLayouts.Body>
<LLMConfiguration />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

@@ -25,7 +25,7 @@ import {
OllamaFetchParams,
OpenRouterFetchParams,
} from "./interfaces";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
import { SvgAws, SvgOpenrouter, SvgServer } from "@opal/icons";

// Aggregator providers that host models from multiple vendors
export const AGGREGATOR_PROVIDERS = new Set([
@@ -69,6 +69,9 @@ export const getProviderIcon = (
bedrock_converse: SvgAws,
openrouter: SvgOpenrouter,
vertex_ai: GeminiIcon,

// Custom providers
litellm: SvgServer,
};

const lowerProviderName = providerName.toLowerCase();

@@ -16,6 +16,7 @@ import LLMSelector from "@/components/llm/LLMSelector";
import { useVisionProviders } from "./hooks/useVisionProviders";
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
import { SvgAlertTriangle } from "@opal/icons";
import { useLlmManager } from "@/lib/hooks";

export function Checkbox({
label,

@@ -24,17 +24,15 @@ export function useVisionProviders(setPopup: SetPopup) {
setError(null);
try {
const data = await fetchVisionProviders();
setVisionProviders(data);
setVisionProviders(data.providers);

// Find the default vision provider and set it
const defaultProvider = data.find(
(provider) => provider.is_default_vision_provider
const defaultProvider = data.providers.find(
(provider) => provider.id === data.default_vision?.provider_id
);

if (defaultProvider) {
const modelToUse =
defaultProvider.default_vision_model ||
defaultProvider.default_model_name;
const modelToUse = data.default_vision?.model_name;

if (modelToUse && defaultProvider.vision_models.includes(modelToUse)) {
setVisionLLM(

@@ -1,5 +1,8 @@
import { useMemo, useState, useCallback } from "react";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import {
BuildLlmSelection,
getBuildLlmSelection,
@@ -16,7 +19,8 @@ import {
* 2. Smart default - via getDefaultLlmSelection()
*/
export function useBuildLlmSelection(
llmProviders: LLMProviderDescriptor[] | undefined
llmProviders: LLMProviderDescriptor[] | undefined,
defaultLlmModel?: DefaultModel
) {
const [selection, setSelectionState] = useState<BuildLlmSelection | null>(
() => getBuildLlmSelection()
@@ -42,7 +46,19 @@ export function useBuildLlmSelection(
}

// Fall back to smart default
return getDefaultLlmSelection(llmProviders);
return getDefaultLlmSelection(
llmProviders?.map((p) => ({
name: p.name,
provider: p.provider,
default_model_name: (() => {
if (p.id === defaultLlmModel?.provider_id) {
return defaultLlmModel?.model_name ?? "";
}
return p.model_configurations[0]?.name ?? "";
})(),
is_default_provider: p.id === defaultLlmModel?.provider_id,
}))
);
}, [selection, llmProviders, isSelectionValid]);

// Update selection and persist to cookie

@@ -59,6 +59,7 @@ export function BuildOnboardingProvider({
<BuildOnboardingModal
mode={controller.mode}
llmProviders={controller.llmProviders}
defaultLlm={controller.defaultText}
initialValues={controller.initialValues}
isAdmin={controller.isAdmin}
hasUserInfo={controller.hasUserInfo}

@@ -18,7 +18,10 @@ import {
getBuildLlmSelection,
getDefaultLlmSelection,
} from "@/app/craft/onboarding/constants";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
import {
buildInitialValues,
@@ -36,12 +39,25 @@ import OnboardingLlmSetup, {
* Used when user completes onboarding without going through LLM setup step.
*/
function autoSelectBestLlm(
llmProviders: LLMProviderDescriptor[] | undefined
llmProviders: LLMProviderDescriptor[] | undefined,
defaultLlmModel?: DefaultModel
): void {
// Don't override if user already has a selection
if (getBuildLlmSelection()) return;

const selection = getDefaultLlmSelection(llmProviders);
const selection = getDefaultLlmSelection(
llmProviders?.map((p) => ({
name: p.name,
provider: p.provider,
default_model_name: (() => {
if (p.id === defaultLlmModel?.provider_id) {
return defaultLlmModel?.model_name ?? "";
}
return p.model_configurations[0]?.name ?? "";
})(),
is_default_provider: p.id === defaultLlmModel?.provider_id,
}))
);
if (selection) {
setBuildLlmSelection(selection);
}
@@ -57,6 +73,7 @@ interface InitialValues {
interface BuildOnboardingModalProps {
mode: OnboardingModalMode;
llmProviders?: LLMProviderDescriptor[];
defaultLlm?: DefaultModel;
initialValues: InitialValues;
isAdmin: boolean;
hasUserInfo: boolean;
@@ -103,6 +120,7 @@ function getStepsForMode(
export default function BuildOnboardingModal({
mode,
llmProviders,
defaultLlm,
initialValues,
isAdmin,
hasUserInfo,
@@ -271,8 +289,12 @@ export default function BuildOnboardingModal({
if (!llmProviders || llmProviders.length === 0) {
const newProvider = await response.json();
if (newProvider?.id) {
await fetch(`${LLM_PROVIDERS_ADMIN_URL}/${newProvider.id}/default`, {
await fetch(`${LLM_PROVIDERS_ADMIN_URL}/default`, {
method: "POST",
body: JSON.stringify({
provider_id: newProvider.id,
model_name: selectedModel,
}),
});
}
}
@@ -322,7 +344,7 @@ export default function BuildOnboardingModal({

// Auto-select best available LLM if user didn't go through LLM setup
// (e.g., non-admin users or when all providers already configured)
autoSelectBestLlm(llmProviders);
autoSelectBestLlm(llmProviders, defaultLlm);

// Validate workArea is provided before submission
if (!workArea) {

@@ -46,6 +46,7 @@ export function useOnboardingModal(): OnboardingModalController {
const { user, isAdmin, refreshUser } = useUser();
const {
llmProviders,
defaultText,
isLoading: isLoadingLlm,
refetch: refetchLlmProviders,
} = useLLMProviders();
@@ -168,6 +169,7 @@ export function useOnboardingModal(): OnboardingModalController {
openLlmSetup,
close,
llmProviders,
defaultText: defaultText ?? undefined,
initialValues,
completeUserInfo,
completeLlmSetup,

@@ -36,6 +36,9 @@ export interface OnboardingModalController {
llmProviders:
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| undefined;
defaultText:
| import("@/app/admin/configuration/llm/interfaces").DefaultModel
| undefined;
initialValues: {
firstName: string;
lastName: string;
@@ -54,7 +57,9 @@ export interface OnboardingModalController {
completeUserInfo: (info: BuildUserInfo) => Promise<void>;
completeLlmSetup: () => Promise<void>;
refetchLlmProviders: () => Promise<
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| import("@/app/admin/configuration/llm/interfaces").LLMProviderResponse<
import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor
>
| undefined
>;
}

@@ -78,7 +78,7 @@ interface SelectedConnectorState {
*/
export default function BuildConfigPage() {
const { isAdmin, isCurator } = useUser();
const { llmProviders } = useLLMProviders();
const { llmProviders, defaultText } = useLLMProviders();
const { openPersonaEditor, openLlmSetup } = useOnboarding();
const [selectedConnector, setSelectedConnector] =
useState<SelectedConnectorState | null>(null);
@@ -109,7 +109,7 @@ export default function BuildConfigPage() {

// Build mode LLM selection (cookie-based)
const { selection: llmSelection, updateSelection: updateLlmSelection } =
useBuildLlmSelection(llmProviders);
useBuildLlmSelection(llmProviders, defaultText ?? undefined);

// Read demo data from cookie (single source of truth)
const [demoDataEnabled, setDemoDataEnabledLocal] = useState(() =>

@@ -1,4 +1,5 @@
import {
LLMProviderResponse,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
@@ -36,9 +37,14 @@ export async function checkLlmProvider(user: User | null) {
const [providerResponse, optionsResponse, defaultCheckResponse] =
await Promise.all(tasks);

let providers: LLMProviderView[] = [];
let providers: LLMProviderResponse<LLMProviderView> = {
providers: [],
default_text: null,
default_vision: null,
};
if (providerResponse?.ok) {
providers = await providerResponse.json();
providers =
(await providerResponse.json()) as LLMProviderResponse<LLMProviderView>;
}

let options: WellKnownLLMProviderDescriptor[] = [];
@@ -52,5 +58,5 @@ export async function checkLlmProvider(user: User | null) {
setDefaultLLMProviderTestComplete();
}

return { providers, options, defaultCheckSuccessful };
return { providers: providers.providers, options, defaultCheckSuccessful };
}

@@ -2,7 +2,10 @@

import { useMemo } from "react";
import { parseLlmDescriptor, structureValue } from "@/lib/llm/utils";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { createIcon } from "@/components/icons/icons";
@@ -139,17 +142,6 @@ export default function LLMSelector({
});
}, [llmOptions]);

const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);

const defaultModelName = defaultProvider?.default_model_name;
const defaultModelConfig = defaultProvider?.model_configurations.find(
(m) => m.name === defaultModelName
);
const defaultModelDisplayName = defaultModelConfig
? defaultModelConfig.display_name || defaultModelConfig.name
: defaultModelName || null;
const defaultLabel = userSettings ? "System Default" : "User Default";

// Determine if we should show grouped view (only if we have multiple vendors)
@@ -164,16 +156,7 @@ export default function LLMSelector({

<InputSelect.Content>
{!excludePublicProviders && (
<InputSelect.Item
value="default"
description={
userSettings && defaultModelDisplayName
? `(${defaultModelDisplayName})`
: undefined
}
>
{defaultLabel}
</InputSelect.Item>
<InputSelect.Item value="default">{defaultLabel}</InputSelect.Item>
)}
{showGrouped
? groupedOptions.map((group) => (

@@ -974,7 +974,8 @@ export default function useChatController({
const [_, llmModel] = getFinalLLM(
llmManager.llmProviders || [],
liveAssistant || null,
llmManager.currentLlm
llmManager.currentLlm,
llmManager.defaultLlmModel
);
const llmAcceptsImages = modelSupportsImageInput(
llmManager.llmProviders || [],

web/src/lib/admin/llm/svc.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
import { LLM_ADMIN_URL } from "@/app/admin/configuration/llm/constants";

export function setDefaultLlmModel(providerId: number, modelName: string) {
const response = fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: providerId,
model_name: modelName,
}),
});

return response;
}
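A usage sketch for the helper above (the caller context is illustrative): it returns the fetch promise unawaited, so callers await it and check ok themselves.

import { setDefaultLlmModel } from "@/lib/admin/llm/svc";

async function markDefault(providerId: number, modelName: string) {
  const response = await setDefaultLlmModel(providerId, modelName);
  if (!response.ok) {
    const errorMsg = (await response.json()).detail;
    throw new Error(`Failed to set default model: ${errorMsg}`);
  }
}
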
@@ -32,7 +32,10 @@ import {
MinimalPersonaSnapshot,
PersonaLabel,
} from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/utils";
import { getSourceMetadataForSources } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
@@ -471,6 +474,8 @@ export interface LlmDescriptor {
export interface LlmManager {
currentLlm: LlmDescriptor;
updateCurrentLlm: (newOverride: LlmDescriptor) => void;
defaultLlmModel?: DefaultModel;
updateDefaultLlmModel: (newDefaultLlmModel: DefaultModel) => void;
temperature: number;
updateTemperature: (temperature: number) => void;
updateModelOverrideBasedOnChatSession: (chatSession?: ChatSession) => void;
@@ -525,16 +530,20 @@ providing appropriate defaults for new conversations based on the available tool
*/

function getDefaultLlmDescriptor(
llmProviders: LLMProviderDescriptor[]
llmProviders: LLMProviderDescriptor[],
defaultLlmModel?: DefaultModel
): LlmDescriptor | null {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
(provider) => provider.id === defaultLlmModel?.provider_id
);
if (defaultProvider) {
return {
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
modelName:
defaultLlmModel?.model_name ??
defaultProvider.model_configurations[0]?.name ??
"",
};
}
const firstLlmProvider = llmProviders.find(
@@ -544,7 +553,7 @@ function getDefaultLlmDescriptor(
return {
name: firstLlmProvider.name,
provider: firstLlmProvider.provider,
modelName: firstLlmProvider.default_model_name,
modelName: firstLlmProvider.model_configurations[0]?.name ?? "",
};
}
return null;
@@ -558,8 +567,13 @@ export function useLlmManager(

// Get all user-accessible providers via SWR (general providers - no persona filter)
// This includes public + all restricted providers user can access via groups
const { llmProviders: allUserProviders, isLoading: isLoadingAllProviders } =
useLLMProviders();
const {
llmProviders: allUserProviders,
defaultText,
isLoading: isLoadingAllProviders,
} = useLLMProviders();

const defaultTextLlmModel = defaultText ?? undefined;
// Fetch persona-specific providers to enforce RBAC restrictions per assistant
// Only fetch if we have an assistant selected
const personaId =
@@ -581,6 +595,13 @@ export function useLlmManager(
modelName: "",
});

const [defaultLlmModel, setDefaultLlmModel] = useState<
DefaultModel | undefined
>(defaultTextLlmModel);
const updateDefaultLlmModel = (newDefaultLlmModel: DefaultModel) => {
setDefaultLlmModel(newDefaultLlmModel);
};

// Track the previous assistant ID to detect when it changes
const prevAssistantIdRef = useRef<number | undefined>(undefined);

@@ -629,7 +650,10 @@ export function useLlmManager(
} else if (user?.preferences?.default_model) {
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
} else {
const defaultLlm = getDefaultLlmDescriptor(llmProviders);
const defaultLlm = getDefaultLlmDescriptor(
llmProviders,
defaultTextLlmModel
);
if (defaultLlm) {
setCurrentLlm(defaultLlm);
}
@@ -679,7 +703,7 @@ export function useLlmManager(

// Model not found in available providers - fall back to default model
return (
getDefaultLlmDescriptor(llmProviders) ?? {
getDefaultLlmDescriptor(llmProviders, defaultTextLlmModel) ?? {
name: "",
provider: "",
modelName: "",
@@ -787,6 +811,8 @@ export function useLlmManager(
updateModelOverrideBasedOnChatSession,
currentLlm,
updateCurrentLlm,
defaultLlmModel,
updateDefaultLlmModel,
temperature,
updateTemperature,
imageFilesPresent,

@@ -1,5 +1,8 @@
import useSWR from "swr";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import {
LLMProviderDescriptor,
LLMProviderResponse,
} from "@/app/admin/configuration/llm/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";

export function useLLMProviders(personaId?: number) {
@@ -12,17 +15,17 @@ export function useLLMProviders(personaId?: number) {
? `/api/llm/persona/${personaId}/providers`
: "/api/llm/provider";

const { data, error, mutate } = useSWR<LLMProviderDescriptor[] | undefined>(
url,
errorHandlingFetcher,
{
revalidateOnFocus: false, // Cache aggressively for performance
dedupingInterval: 60000, // Dedupe requests within 1 minute
}
);
const { data, error, mutate } = useSWR<
LLMProviderResponse<LLMProviderDescriptor> | undefined
>(url, errorHandlingFetcher, {
revalidateOnFocus: false, // Cache aggressively for performance
dedupingInterval: 60000, // Dedupe requests within 1 minute
});

return {
llmProviders: data,
llmProviders: data?.providers,
defaultText: data?.default_text,
defaultVision: data?.default_vision,
isLoading: !error && !data,
error,
refetch: mutate,

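Call sites now destructure the split fields instead of receiving a bare array; a sketch (the component and import path are illustrative):

import { useLLMProviders } from "@/lib/hooks"; // path illustrative

function ProviderSummary() {
  const { llmProviders, defaultText, isLoading } = useLLMProviders();
  if (isLoading || !llmProviders) return null;

  // The default is resolved by matching provider_id, not a per-provider flag.
  const defaultProvider = llmProviders.find(
    (p) => p.id === defaultText?.provider_id
  );
  return (
    <span>
      {llmProviders.length} providers
      {defaultProvider ? ` (default: ${defaultText?.model_name})` : ""}
    </span>
  );
}
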
@@ -1,14 +0,0 @@
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { fetchSS } from "../utilsSS";

export async function fetchLLMProvidersSS() {
// Test helper: allow Playwright runs to force an empty provider list so onboarding appears.
if (process.env.PLAYWRIGHT_FORCE_EMPTY_LLM_PROVIDERS === "true") {
return [];
}
const response = await fetchSS("/llm/provider");
if (response.ok) {
return (await response.json()) as LLMProviderDescriptor[];
}
return [];
}
@@ -1,5 +1,6 @@
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
import {
DefaultModel,
LLMProviderDescriptor,
ModelConfiguration,
} from "@/app/admin/configuration/llm/interfaces";
@@ -8,14 +9,15 @@ import { LlmDescriptor } from "@/lib/hooks";
export function getFinalLLM(
llmProviders: LLMProviderDescriptor[],
persona: MinimalPersonaSnapshot | null,
currentLlm: LlmDescriptor | null
currentLlm: LlmDescriptor | null,
defaultLlmModel?: DefaultModel
): [string, string] {
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
(llmProvider) => llmProvider.id === defaultLlmModel?.provider_id
);

let provider = defaultProvider?.provider || "";
let model = defaultProvider?.default_model_name || "";
let model = defaultLlmModel?.model_name || "";

if (persona) {
// Map "provider override" to actual LLMProvider

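With the extra parameter, call sites pass the tenant default explicitly, mirroring the useChatController hunk earlier in this diff:

const [finalProvider, finalModel] = getFinalLLM(
  llmManager.llmProviders || [],
  liveAssistant || null,
  llmManager.currentLlm,
  llmManager.defaultLlmModel
);
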
@@ -1,6 +1,11 @@
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
import {
LLMProviderResponse,
VisionProvider,
} from "@/app/admin/configuration/llm/interfaces";

export async function fetchVisionProviders(): Promise<VisionProvider[]> {
export async function fetchVisionProviders(): Promise<
LLMProviderResponse<VisionProvider>
> {
const response = await fetch("/api/admin/llm/vision-providers", {
headers: {
"Content-Type": "application/json",

@@ -93,15 +93,7 @@ export const testApiKeyHelper = async (
...(formValues?.custom_config ?? {}),
...(customConfigOverride ?? {}),
},
default_model_name: modelName ?? formValues?.default_model_name ?? "",
model_configurations: [
...(formValues.model_configurations || []).map(
(model: ModelConfiguration) => ({
name: model.name,
is_visible: true,
})
),
],
model: modelName ?? formValues?.default_model_name ?? "",
};

return await submitLlmTestRequest(

@@ -6,7 +6,10 @@ import {
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
import {
LLM_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/app/admin/configuration/llm/constants";
import { OnboardingActions, OnboardingState } from "../types";
import { APIFormFieldState } from "@/refresh-components/form/types";
import {
@@ -225,10 +228,19 @@ export function OnboardingFormWrapper<T extends Record<string, any>>({
try {
const newLlmProvider = await response.json();
if (newLlmProvider?.id != null) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{ method: "POST" }
);
const defaultLLMModel =
payload.default_model_name ??
newLlmProvider.model_configurations[0].name;
const setDefaultResponse = await fetch(`${LLM_ADMIN_URL}/default`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider_id: newLlmProvider.id,
model_name: defaultLLMModel,
}),
});
if (!setDefaultResponse.ok) {
const err = await setDefaultResponse.json().catch(() => ({}));
console.error("Failed to set provider as default", err?.detail);

@@ -401,8 +401,14 @@ test("test, create, and set as default", async () => {
expect.objectContaining({ method: "PUT" })
);
expect(fetchSpy).toHaveBeenCalledWith(
"/api/llm/provider/5/default",
expect.objectContaining({ method: "POST" })
"/api/llm/provider/default",
expect.objectContaining({
method: "POST",
body: JSON.stringify({
provider_id: 5,
model_name: "gpt-4"
})
})
);
});
});

@@ -67,7 +67,7 @@ test.describe("First user onboarding flow", () => {
}
);

await page.route("**/api/admin/llm/provider/*/default", async (route) => {
await page.route("**/api/admin/llm/default", async (route) => {
if (route.request().method() === "POST") {
await route.fulfill({
status: 200,
