mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-09 00:42:47 +00:00
Compare commits
11 Commits
cli/v0.2.1
...
richard/mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e8f00e73b | ||
|
|
ebc4c76b98 | ||
|
|
87456f1b58 | ||
|
|
166a14a249 | ||
|
|
6f19010c0b | ||
|
|
14edd41206 | ||
|
|
1f2b328d84 | ||
|
|
941aff3620 | ||
|
|
11b087d78c | ||
|
|
ad91752c9a | ||
|
|
60c9ad4ed4 |
@@ -0,0 +1,58 @@
|
||||
"""update ModelConfiguration recommendation fields to boolean

Revision ID: 62b99efedb8c
Revises: 0816326d83aa
Create Date: 2025-07-15 14:30:13.501302

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "62b99efedb8c"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None

# The three recommendation flags this revision manages; factoring the names
# out keeps upgrade() and downgrade() trivially in sync.
_RECOMMENDATION_COLUMNS = (
    "recommended_default",
    "recommended_fast_default",
    "recommended_is_visible",
)


def upgrade() -> None:
    """Add nullable boolean recommendation columns to model_configuration.

    server_default makes the change safe for existing rows.
    """
    for column_name in _RECOMMENDATION_COLUMNS:
        op.add_column(
            "model_configuration",
            sa.Column(
                column_name,
                sa.Boolean(),
                nullable=True,
                default=False,
                server_default=sa.text("false"),
            ),
        )


def downgrade() -> None:
    """Drop the boolean recommendation columns from model_configuration."""
    for column_name in _RECOMMENDATION_COLUMNS:
        op.drop_column("model_configuration", column_name)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""add use recommended models attribute to LLMProvider

Revision ID: e1eb0b1f9ece
Revises: 62b99efedb8c
Create Date: 2025-07-15 17:29:30.183582

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "e1eb0b1f9ece"
down_revision = "62b99efedb8c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add the non-null use_recommended_models flag to llm_provider.

    The server_default backfills existing rows to false, so the NOT NULL
    constraint is safe to apply in one step.
    """
    op.add_column(
        "llm_provider",
        sa.Column(
            "use_recommended_models",
            sa.Boolean(),
            nullable=False,
            default=False,
            server_default=sa.text("false"),
        ),
    )


def downgrade() -> None:
    """Remove the use_recommended_models flag from llm_provider."""
    op.drop_column("llm_provider", "use_recommended_models")
|
||||
@@ -297,6 +297,8 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated",
|
||||
"onyx.background.celery.tasks.update_recommended_selected_models",
|
||||
"onyx.background.celery.tasks.user_file_folder_sync",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
]
|
||||
|
||||
@@ -128,6 +128,15 @@ beat_task_templates: list[dict] = [
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "update-recommended-selected-models",
|
||||
"task": OnyxCeleryTask.UPDATE_RECOMMENDED_SELECTED_MODELS,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
@@ -167,6 +176,19 @@ if LLM_MODEL_UPDATE_API_URL:
|
||||
},
|
||||
}
|
||||
)
|
||||
# Otherwise, we use our own curated list to check for model updates
|
||||
else:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update-onyx-curated",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE_ONYX_CURATED,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.llm.llm_provider_options import curated_models
|
||||
|
||||
|
||||
def _delete_deprecated_not_visible_models(
    db_session: Session,
    llm_provider: LLMProvider,
    model_configurations: list[ModelConfiguration],
) -> None:
    """
    Delete model configurations that are marked as deprecated in
    curated_models and are not visible.

    Models that are deprecated but still visible are kept — an admin
    explicitly enabled them, so they are not removed out from under users.

    Args:
        db_session: active session; rows are deleted but NOT committed here.
        llm_provider: provider whose curated list is consulted; caller
            guarantees ``llm_provider.provider`` is a key of curated_models.
        model_configurations: all configurations for this provider.
    """
    # Build the deprecated-name lookup once instead of rescanning the curated
    # list for every model configuration (was O(models * curated_models), and
    # could enqueue the same configuration twice on duplicate curated names).
    deprecated_names = {
        curated_model["name"]
        for curated_model in curated_models[llm_provider.provider]
        if curated_model.get("deprecated", False)
    }
    model_ids_to_delete = [
        model_configuration.id
        for model_configuration in model_configurations
        if model_configuration.name in deprecated_names
        and not model_configuration.is_visible
    ]
    if model_ids_to_delete:
        deleted_count = (
            db_session.query(ModelConfiguration)
            .filter(ModelConfiguration.id.in_(model_ids_to_delete))
            .delete(synchronize_session=False)
        )
        task_logger.info(
            f"Deleted {deleted_count} models for provider {llm_provider.provider}"
        )
|
||||
|
||||
|
||||
def _sync_model_configurations_with_curated_models(
    db_session: Session,
    llm_provider: LLMProvider,
    model_configurations: list[ModelConfiguration],
) -> None:
    """
    Sync model configurations with curated models.

    - Existing configurations have their recommended_* fields refreshed
      whenever they drift from the curated values.
    - Curated models that have no configuration yet (and are not deprecated)
      are added with their recommended visibility.

    Changes are staged on ``db_session``; the caller commits.
    """
    # The caller already fetched every configuration for this provider, so
    # index them by name rather than issuing one SELECT per curated model
    # (the original ignored the model_configurations parameter entirely).
    configurations_by_name = {mc.name: mc for mc in model_configurations}

    for curated_model in curated_models[llm_provider.provider]:
        model_configuration = configurations_by_name.get(curated_model["name"])
        if model_configuration:
            needs_update = (
                model_configuration.recommended_default
                != curated_model["recommended_default_model"]
                or model_configuration.recommended_fast_default
                != curated_model["recommended_fast_default_model"]
                or model_configuration.recommended_is_visible
                != curated_model["recommended_is_visible"]
            )
            if needs_update:
                db_session.query(ModelConfiguration).filter(
                    ModelConfiguration.id == model_configuration.id
                ).update(
                    {
                        ModelConfiguration.recommended_default: curated_model[
                            "recommended_default_model"
                        ],
                        ModelConfiguration.recommended_fast_default: curated_model[
                            "recommended_fast_default_model"
                        ],
                        ModelConfiguration.recommended_is_visible: curated_model[
                            "recommended_is_visible"
                        ],
                    }
                )
                task_logger.info(
                    f"Updated model configuration for {curated_model['name']} for provider {llm_provider.provider}"
                )
        elif not curated_model.get("deprecated", False):
            # NOTE(review): new rows set is_visible from the curated value but
            # leave recommended_is_visible unset — preserved from the original;
            # confirm whether that asymmetry is intentional.
            db_session.add(
                ModelConfiguration(
                    llm_provider_id=llm_provider.id,
                    name=curated_model["name"],
                    is_visible=curated_model["recommended_is_visible"],
                    recommended_default=curated_model["recommended_default_model"],
                    recommended_fast_default=curated_model[
                        "recommended_fast_default_model"
                    ],
                )
            )
            task_logger.info(
                f"Added model configuration for {curated_model['name']} for provider {llm_provider.provider}"
            )
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE_ONYX_CURATED,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_for_llm_model_update_onyx_curated(
    self: Task, *, tenant_id: str
) -> bool | None:
    """
    Check for LLM model updates using Onyx's curated model list.

    This task is used when LLM_MODEL_UPDATE_API_URL is not configured.
    Providers with any non-curated (admin-added) model are left untouched.

    Returns:
        True on success, None on failure (the exception is logged).
    """
    task_logger.info("Starting Onyx curated LLM model update check")

    try:
        with get_session_with_current_tenant() as db_session:
            llm_providers = db_session.query(LLMProvider).all()
            for llm_provider in llm_providers:
                if llm_provider.provider not in curated_models:
                    continue
                model_configurations = (
                    db_session.query(ModelConfiguration)
                    .filter(ModelConfiguration.llm_provider_id == llm_provider.id)
                    .all()
                )
                # Set (was a list) for O(1) membership tests below.
                curated_model_names = {
                    model["name"] for model in curated_models[llm_provider.provider]
                }
                # Any model outside the curated list means the admin customized
                # this provider; skip it entirely rather than fight their setup.
                has_custom_models = any(
                    mc.name not in curated_model_names for mc in model_configurations
                )
                if has_custom_models:
                    continue
                _delete_deprecated_not_visible_models(
                    db_session, llm_provider, model_configurations
                )
                _sync_model_configurations_with_curated_models(
                    db_session, llm_provider, model_configurations
                )
            db_session.commit()
            task_logger.info(
                "Onyx curated LLM model update check completed successfully"
            )
            return True

    except Exception:
        task_logger.exception("Failed to update models using Onyx curated list")
        return None
|
||||
@@ -0,0 +1,97 @@
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.UPDATE_RECOMMENDED_SELECTED_MODELS,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def update_recommended_selected_models(self: Task, *, tenant_id: str) -> bool | None:
    """
    Update the recommended and selected status of LLM models based on configuration.

    For every provider that opted in via ``use_recommended_models``, this
    syncs each model's is_visible with recommended_is_visible and promotes
    the recommended (fast) default model to the provider's
    default_model_name / fast_default_model_name.

    Returns:
        True on success, None on failure (the exception is logged).
    """
    task_logger.info("Starting recommended selected models update")

    try:
        with get_session_with_current_tenant() as db_session:
            for llm_provider in db_session.query(LLMProvider).all():
                # Guard clause instead of nesting the whole body one level deep.
                if not llm_provider.use_recommended_models:
                    continue
                model_configurations = (
                    db_session.query(ModelConfiguration)
                    .filter(ModelConfiguration.llm_provider_id == llm_provider.id)
                    .all()
                )
                for model_configuration in model_configurations:
                    if (
                        model_configuration.recommended_is_visible is not None
                        and model_configuration.is_visible
                        != model_configuration.recommended_is_visible
                    ):
                        db_session.query(ModelConfiguration).filter(
                            ModelConfiguration.id == model_configuration.id
                        ).update(
                            {
                                ModelConfiguration.is_visible: model_configuration.recommended_is_visible
                            }
                        )
                        task_logger.info(
                            f"Updated is_visible for model {model_configuration.name} for provider {llm_provider.provider}"
                        )
                    if (
                        model_configuration.recommended_default
                        and model_configuration.name != llm_provider.default_model_name
                    ):
                        db_session.query(LLMProvider).filter(
                            LLMProvider.id == llm_provider.id
                        ).update(
                            {LLMProvider.default_model_name: model_configuration.name}
                        )
                        # Single-line message: the original triple-quoted
                        # f-string baked newlines and indentation into the log.
                        task_logger.info(
                            f"Updated default_model_name for model {model_configuration.name} for provider {llm_provider.provider}"
                        )
                    if (
                        model_configuration.recommended_fast_default
                        and model_configuration.name
                        != llm_provider.fast_default_model_name
                    ):
                        db_session.query(LLMProvider).filter(
                            LLMProvider.id == llm_provider.id
                        ).update(
                            {
                                LLMProvider.fast_default_model_name: model_configuration.name
                            }
                        )
                        task_logger.info(
                            f"Updated fast_default_model_name for model {model_configuration.name} for provider {llm_provider.provider}"
                        )
            db_session.commit()
            task_logger.info(
                "Recommended selected models update completed successfully"
            )
            return True

    except Exception:
        task_logger.exception("Failed to update recommended selected models")
        return None
|
||||
@@ -443,6 +443,9 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE_ONYX_CURATED = "check_for_llm_model_update_onyx_curated"
|
||||
UPDATE_RECOMMENDED_SELECTED_MODELS = "update_recommended_selected_models"
|
||||
|
||||
CHECK_FOR_USER_FILE_FOLDER_SYNC = "check_for_user_file_folder_sync"
|
||||
|
||||
# Connector checkpoint cleanup
|
||||
|
||||
@@ -76,7 +76,15 @@ def upsert_llm_provider(
|
||||
)
|
||||
|
||||
if not existing_llm_provider:
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
|
||||
use_recommended_models = (
|
||||
llm_provider_upsert_request.use_recommended_models
|
||||
if llm_provider_upsert_request.use_recommended_models is not None
|
||||
else False
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(
|
||||
name=llm_provider_upsert_request.name,
|
||||
use_recommended_models=use_recommended_models,
|
||||
)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
|
||||
@@ -31,6 +31,7 @@ from sqlalchemy import Integer
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
@@ -2265,6 +2266,9 @@ class LLMProvider(Base):
|
||||
)
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
use_recommended_models: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
@@ -2296,6 +2300,12 @@ class ModelConfiguration(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
recommended_default: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
recommended_fast_default: Mapped[bool | None] = mapped_column(
|
||||
Boolean, nullable=True
|
||||
)
|
||||
recommended_is_visible: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Represents whether or not a given model will be usable by the end user or not.
|
||||
# This field is primarily used for "Well Known LLM Providers", since for them,
|
||||
# we have a pre-defined list of LLM models that we allow them to choose from.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
import litellm # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.chat_llm import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
@@ -9,6 +9,15 @@ from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
|
||||
|
||||
class CuratedModelDict(TypedDict):
    """Shape of one entry in the curated_models registry."""

    # Model identifier as passed to the provider (e.g. "gpt-4o").
    name: str
    # Human-readable display name (presumably for the admin UI — confirm).
    friendly_name: str
    # Whether this model should be promoted to the provider's default model.
    recommended_default_model: bool
    # Whether this model should be promoted to the provider's fast default.
    recommended_fast_default_model: bool
    # Whether this model should be visible to end users by default.
    recommended_is_visible: bool
    # Deprecated models are never auto-added and, when not visible,
    # may be deleted by the curated-update task.
    deprecated: bool
|
||||
|
||||
|
||||
class CustomConfigKeyType(Enum):
|
||||
# used for configuration values that require manual input
|
||||
# i.e., textual API keys (e.g., "abcd1234")
|
||||
@@ -43,102 +52,534 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
deployment_name_required: bool = False
|
||||
# set for providers like Azure, which support a single model per deployment.
|
||||
single_model_supported: bool = False
|
||||
# indicates whether this provider has curated models available
|
||||
has_curated_models: bool = False
|
||||
|
||||
|
||||
# TODO Before shipping:
|
||||
# Backfill existing litellm models (with deprecated: True if we wish to hide) for backwards compatibility
|
||||
curated_models: dict[str, list[CuratedModelDict]] = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "o1",
|
||||
"friendly_name": "OpenAI o1",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "o3-mini",
|
||||
"friendly_name": "OpenAI o3 Mini",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"friendly_name": "GPT 4o",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o-mini",
|
||||
"friendly_name": "GPT 4o Mini",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "o3",
|
||||
"friendly_name": "OpenAI o3",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "o4-mini",
|
||||
"friendly_name": "OpenAI o4 Mini",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "o1-mini",
|
||||
"friendly_name": "OpenAI o1 Mini",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4.1",
|
||||
"friendly_name": "GPT 4.1",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"friendly_name": "GPT 4",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4-turbo",
|
||||
"friendly_name": "GPT 4 Turbo",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
],
|
||||
"anthropic": [
|
||||
{
|
||||
"name": "claude-4-sonnet-20250514",
|
||||
"friendly_name": "Claude 4 Sonnet",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-3-7-sonnet-20250219",
|
||||
"friendly_name": "Claude 3.7 Sonnet (February 2025)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-3-5-sonnet-20241022",
|
||||
"friendly_name": "Claude 3.5 Sonnet (October 2024)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-3-opus-20240229",
|
||||
"friendly_name": "Claude 3 Opus",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": True,
|
||||
},
|
||||
{
|
||||
"name": "claude-3-sonnet-20240229",
|
||||
"friendly_name": "Claude 3 Sonnet",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-3-haiku-20240307",
|
||||
"friendly_name": "Claude 3 Haiku",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": True,
|
||||
},
|
||||
],
|
||||
"vertex_ai": [
|
||||
{
|
||||
"name": "gemini-2.0-flash",
|
||||
"friendly_name": "Gemini 2.0 Flash",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.0-flash-lite",
|
||||
"friendly_name": "Gemini 2.0 Flash Lite",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.5-pro-preview-06-05",
|
||||
"friendly_name": "Gemini 2.5 Pro Preview (June 2024)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.5-pro-preview-05-06",
|
||||
"friendly_name": "Gemini 2.5 Pro Preview (May 2024)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.0-flash-lite-001",
|
||||
"friendly_name": "Gemini 2.0 Flash Lite (Version 001)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.0-flash-001",
|
||||
"friendly_name": "Gemini 2.0 Flash (Version 001)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-2.0-flash-exp",
|
||||
"friendly_name": "Gemini 2.0 Flash Experimental",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-1.5-pro",
|
||||
"friendly_name": "Gemini 1.5 Pro",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-1.5-pro-001",
|
||||
"friendly_name": "Gemini 1.5 Pro (Version 001)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-1.5-pro-002",
|
||||
"friendly_name": "Gemini 1.5 Pro (Version 002)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-1.5-flash",
|
||||
"friendly_name": "Gemini 1.5 Flash",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-1.5-flash-001",
|
||||
"friendly_name": "Gemini 1.5 Flash (Version 001)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gemini-1.5-flash-002",
|
||||
"friendly_name": "Gemini 1.5 Flash (Version 002)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-sonnet-4",
|
||||
"friendly_name": "Claude Sonnet 4 (Vertex AI)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-opus-4",
|
||||
"friendly_name": "Claude Opus 4 (Vertex AI)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "claude-3-7-sonnet@20250219",
|
||||
"friendly_name": "Claude 3.7 Sonnet (Vertex AI)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
],
|
||||
"bedrock": [
|
||||
{
|
||||
"name": "anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
"friendly_name": "Claude 3.7 Sonnet (Bedrock)",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"friendly_name": "Claude 3.5 Sonnet (Bedrock)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"friendly_name": "Claude 3 Opus (Bedrock)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"friendly_name": "Claude 3 Sonnet (Bedrock)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"friendly_name": "Claude 3 Haiku (Bedrock)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "meta.llama3-1-70b-instruct-v1:0",
|
||||
"friendly_name": "Llama 3.1 70B Instruct",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "meta.llama3-1-8b-instruct-v1:0",
|
||||
"friendly_name": "Llama 3.1 8B Instruct",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "meta.llama3-2-90b-instruct-v1:0",
|
||||
"friendly_name": "Llama 3.2 90B Instruct",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "meta.llama3-2-11b-instruct-v1:0",
|
||||
"friendly_name": "Llama 3.2 11B Instruct",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "meta.llama3-3-70b-instruct-v1:0",
|
||||
"friendly_name": "Llama 3.3 70B Instruct",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "amazon.nova-micro-v1:0",
|
||||
"friendly_name": "Amazon Nova Micro",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "amazon.nova-lite-v1:0",
|
||||
"friendly_name": "Amazon Nova Lite",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "amazon.nova-pro-v1:0",
|
||||
"friendly_name": "Amazon Nova Pro",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "mistral.mistral-large-2402-v1:0",
|
||||
"friendly_name": "Mistral Large (February 2024)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "mistral.mistral-large-2407-v1:0",
|
||||
"friendly_name": "Mistral Large (July 2024)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "mistral.mistral-small-2402-v1:0",
|
||||
"friendly_name": "Mistral Small (February 2024)",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "ai21.jamba-instruct-v1:0",
|
||||
"friendly_name": "AI21 Jamba Instruct",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "cohere.command-r-plus-v1:0",
|
||||
"friendly_name": "Cohere Command R+",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "cohere.command-r-v1:0",
|
||||
"friendly_name": "Cohere Command R",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_curated_model_names(provider_name: str) -> list[str]:
    """Return the non-deprecated curated model names for a provider.

    Unknown providers yield an empty list.
    """
    provider_models = curated_models.get(provider_name)
    if provider_models is None:
        return []
    return [
        model["name"]
        for model in provider_models
        if not model.get("deprecated", False)
    ]
|
||||
|
||||
|
||||
def get_curated_model_info(
    provider_name: str, model_name: str
) -> CuratedModelDict | None:
    """Look up the curated entry for one provider/model pair.

    Returns ``None`` when the provider is unknown or the model name is not
    present in that provider's curated list.
    """
    provider_entries = curated_models.get(provider_name, [])
    return next(
        (entry for entry in provider_entries if entry["name"] == model_name),
        None,
    )
|
||||
|
||||
|
||||
def get_curated_is_visible(provider_name: str, model_name: str) -> bool:
    """Return the recommended_is_visible flag for one curated model.

    Unknown providers or model names yield ``False``.
    """
    model_info = get_curated_model_info(provider_name, model_name)
    if not model_info:
        return False
    # Read the key with .get() and a False default, matching how every other
    # accessor in this module reads optional curated-model keys.  Direct
    # indexing would raise KeyError for an entry that omits the key.
    return model_info.get("recommended_is_visible", False)
|
||||
|
||||
|
||||
def get_curated_recommended_default_model(provider_name: str) -> str | None:
    """Return the name of the provider's recommended default model.

    Only non-deprecated entries flagged ``recommended_default_model`` qualify;
    ``None`` is returned when the provider is unknown or nothing matches.
    """
    candidates = (
        entry["name"]
        for entry in curated_models.get(provider_name, [])
        if entry.get("recommended_default_model", False)
        and not entry.get("deprecated", False)
    )
    return next(candidates, None)
|
||||
|
||||
|
||||
def get_curated_recommended_fast_default_model(provider_name: str) -> str | None:
    """Return the name of the provider's recommended *fast* default model.

    Only non-deprecated entries flagged ``recommended_fast_default_model``
    qualify; ``None`` is returned when the provider is unknown or nothing
    matches.
    """
    candidates = (
        entry["name"]
        for entry in curated_models.get(provider_name, [])
        if entry.get("recommended_fast_default_model", False)
        and not entry.get("deprecated", False)
    )
    return next(candidates, None)
|
||||
|
||||
|
||||
# TODO Before shipping: These values are taken from curated_models.
|
||||
# Ensure the values before and after this commit are the exact same.
|
||||
# Do that for basically the rest of this file.
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"o4-mini",
|
||||
"o3-mini",
|
||||
"o1-mini",
|
||||
"o3",
|
||||
"o1",
|
||||
"gpt-4",
|
||||
"gpt-4.1",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"o1-preview",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-turbo-preview",
|
||||
"gpt-4-1106-preview",
|
||||
"gpt-4-vision-preview",
|
||||
"gpt-4-0613",
|
||||
"gpt-4o-2024-08-06",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo-1106",
|
||||
"gpt-3.5-turbo-16k",
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
OPEN_AI_MODEL_NAMES = get_curated_model_names("openai")
|
||||
OPEN_AI_VISIBLE_MODEL_NAMES = [
|
||||
model["name"]
|
||||
for model in curated_models.get("openai", [])
|
||||
if model.get("recommended_is_visible", False) and not model.get("deprecated", False)
|
||||
]
|
||||
OPEN_AI_VISIBLE_MODEL_NAMES = ["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"]
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
|
||||
# models
|
||||
BEDROCK_MODEL_NAMES = [
|
||||
model
|
||||
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
|
||||
# litellm has split them into two lists :(
|
||||
for model in litellm.bedrock_models + litellm.bedrock_converse_models
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
BEDROCK_MODEL_NAMES = get_curated_model_names("bedrock")
|
||||
BEDROCK_DEFAULT_MODEL = (
|
||||
get_curated_recommended_default_model("bedrock")
|
||||
or "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
)
|
||||
|
||||
IGNORABLE_ANTHROPIC_MODELS = [
|
||||
"claude-2",
|
||||
"claude-instant-1",
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
]
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
ANTHROPIC_MODEL_NAMES = [
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in IGNORABLE_ANTHROPIC_MODELS
|
||||
][::-1]
|
||||
ANTHROPIC_MODEL_NAMES = get_curated_model_names("anthropic")
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = [
|
||||
"claude-3-5-sonnet-20241022",
|
||||
"claude-3-7-sonnet-20250219",
|
||||
model["name"]
|
||||
for model in curated_models.get("anthropic", [])
|
||||
if model.get("recommended_is_visible", False) and not model.get("deprecated", False)
|
||||
]
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.0-flash"
|
||||
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.0-flash-lite"
|
||||
VERTEXAI_MODEL_NAMES = [
|
||||
# 2.5 pro models
|
||||
"gemini-2.5-pro-preview-06-05",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
# 2.0 flash-lite models
|
||||
VERTEXAI_DEFAULT_FAST_MODEL,
|
||||
"gemini-2.0-flash-lite-001",
|
||||
# "gemini-2.0-flash-lite-preview-02-05",
|
||||
# 2.0 flash models
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
"gemini-2.0-flash-001",
|
||||
"gemini-2.0-flash-exp",
|
||||
# "gemini-2.0-flash-exp-image-generation",
|
||||
# "gemini-2.0-flash-thinking-exp-01-21",
|
||||
# 1.5 pro models
|
||||
"gemini-1.5-pro",
|
||||
"gemini-1.5-pro-001",
|
||||
"gemini-1.5-pro-002",
|
||||
# 1.5 flash models
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-flash-001",
|
||||
"gemini-1.5-flash-002",
|
||||
# Anthropic models
|
||||
"claude-sonnet-4",
|
||||
"claude-opus-4",
|
||||
"claude-3-7-sonnet@20250219",
|
||||
]
|
||||
VERTEXAI_DEFAULT_MODEL = (
|
||||
get_curated_recommended_default_model("vertex_ai") or "gemini-2.0-flash"
|
||||
)
|
||||
VERTEXAI_DEFAULT_FAST_MODEL = (
|
||||
get_curated_recommended_fast_default_model("vertex_ai") or "gemini-2.0-flash-lite"
|
||||
)
|
||||
VERTEXAI_MODEL_NAMES = get_curated_model_names("vertex_ai")
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = [
|
||||
VERTEXAI_DEFAULT_MODEL,
|
||||
VERTEXAI_DEFAULT_FAST_MODEL,
|
||||
model["name"]
|
||||
for model in curated_models.get("vertex_ai", [])
|
||||
if model.get("recommended_is_visible", False) and not model.get("deprecated", False)
|
||||
]
|
||||
|
||||
|
||||
@@ -149,9 +590,14 @@ _PROVIDER_TO_MODELS_MAP = {
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
|
||||
}
|
||||
|
||||
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
|
||||
_PROVIDER_TO_VISIBLE_MODELS_MAP: dict[str, list[str]] = {
|
||||
OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
|
||||
BEDROCK_PROVIDER_NAME: [BEDROCK_DEFAULT_MODEL],
|
||||
BEDROCK_PROVIDER_NAME: [
|
||||
model["name"]
|
||||
for model in curated_models.get("bedrock", [])
|
||||
if model.get("recommended_is_visible", False)
|
||||
and not model.get("deprecated", False)
|
||||
],
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
}
|
||||
@@ -169,8 +615,11 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
OPENAI_PROVIDER_NAME
|
||||
),
|
||||
default_model="gpt-4o",
|
||||
default_fast_model="gpt-4o-mini",
|
||||
default_model=get_curated_recommended_default_model(OPENAI_PROVIDER_NAME),
|
||||
default_fast_model=get_curated_recommended_fast_default_model(
|
||||
OPENAI_PROVIDER_NAME
|
||||
),
|
||||
has_curated_models=OPENAI_PROVIDER_NAME in curated_models,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=ANTHROPIC_PROVIDER_NAME,
|
||||
@@ -182,8 +631,13 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
),
|
||||
default_model="claude-3-7-sonnet-20250219",
|
||||
default_fast_model="claude-3-5-sonnet-20241022",
|
||||
default_model=get_curated_recommended_default_model(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
),
|
||||
default_fast_model=get_curated_recommended_fast_default_model(
|
||||
ANTHROPIC_PROVIDER_NAME
|
||||
),
|
||||
has_curated_models=ANTHROPIC_PROVIDER_NAME in curated_models,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=AZURE_PROVIDER_NAME,
|
||||
@@ -197,6 +651,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
),
|
||||
deployment_name_required=True,
|
||||
single_model_supported=True,
|
||||
has_curated_models=AZURE_PROVIDER_NAME in curated_models,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=BEDROCK_PROVIDER_NAME,
|
||||
@@ -226,8 +681,11 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
),
|
||||
default_model=BEDROCK_DEFAULT_MODEL,
|
||||
default_fast_model=BEDROCK_DEFAULT_MODEL,
|
||||
default_model=get_curated_recommended_default_model(BEDROCK_PROVIDER_NAME),
|
||||
default_fast_model=get_curated_recommended_fast_default_model(
|
||||
BEDROCK_PROVIDER_NAME
|
||||
),
|
||||
has_curated_models=BEDROCK_PROVIDER_NAME in curated_models,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=VERTEXAI_PROVIDER_NAME,
|
||||
@@ -258,8 +716,11 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
default_value="us-east1",
|
||||
),
|
||||
],
|
||||
default_model=VERTEXAI_DEFAULT_MODEL,
|
||||
default_fast_model=VERTEXAI_DEFAULT_MODEL,
|
||||
default_model=get_curated_recommended_default_model(VERTEXAI_PROVIDER_NAME),
|
||||
default_fast_model=get_curated_recommended_fast_default_model(
|
||||
VERTEXAI_PROVIDER_NAME
|
||||
),
|
||||
has_curated_models=VERTEXAI_PROVIDER_NAME in curated_models,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -284,22 +745,33 @@ def fetch_visible_model_names_for_provider_as_set(
|
||||
|
||||
def fetch_model_configurations_for_provider(
    provider_name: str,
    include_deprecated: bool = False,
) -> list[ModelConfigurationView]:
    """Build ModelConfigurationView objects for one provider from curated_models.

    Visibility for each model comes from its curated ``recommended_is_visible``
    flag (via get_curated_is_visible).  Deprecated entries are filtered out
    unless *include_deprecated* is True.  Providers absent from
    ``curated_models`` produce an empty list.
    """
    views: list[ModelConfigurationView] = []
    for entry in curated_models.get(provider_name, []):
        # Skip deprecated models unless the caller explicitly asked for them.
        if entry.get("deprecated", False) and not include_deprecated:
            continue
        views.append(
            ModelConfigurationView(
                name=entry["name"],
                is_visible=get_curated_is_visible(provider_name, entry["name"]),
                max_input_tokens=None,
                supports_image_input=model_supports_image_input(
                    model_name=entry["name"],
                    model_provider=provider_name,
                ),
            )
        )
    return views
|
||||
|
||||
|
||||
def fetch_all_model_configurations_for_provider(
    provider_name: str,
) -> list[ModelConfigurationView]:
    """Fetch every model configuration for a provider, deprecated ones included.

    Thin convenience wrapper that forwards to
    fetch_model_configurations_for_provider with ``include_deprecated=True``.
    """
    return fetch_model_configurations_for_provider(
        provider_name,
        include_deprecated=True,
    )
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llm
|
||||
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.llm_provider_options import fetch_all_model_configurations_for_provider
|
||||
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import get_llm_contextual_cost
|
||||
@@ -51,6 +52,39 @@ def fetch_llm_options(
|
||||
return fetch_available_well_known_llms()
|
||||
|
||||
|
||||
@admin_router.get("/built-in/options-with-all-models")
def fetch_llm_options_with_all_models(
    _: User | None = Depends(current_admin_user),
) -> dict[str, list[WellKnownLLMProviderDescriptor]]:
    """Fetch both regular model configurations (without deprecated) and all model configurations (including deprecated)."""
    regular_options = fetch_available_well_known_llms()

    # Rebuild every descriptor, swapping in the configuration list that also
    # carries deprecated models; all other fields are copied unchanged.
    all_options = [
        WellKnownLLMProviderDescriptor(
            name=option.name,
            display_name=option.display_name,
            api_key_required=option.api_key_required,
            api_base_required=option.api_base_required,
            api_version_required=option.api_version_required,
            custom_config_keys=option.custom_config_keys,
            model_configurations=fetch_all_model_configurations_for_provider(
                option.name
            ),
            default_model=option.default_model,
            default_fast_model=option.default_fast_model,
            deployment_name_required=option.deployment_name_required,
            single_model_supported=option.single_model_supported,
            has_curated_models=option.has_curated_models,
        )
        for option in regular_options
    ]

    return {"regular": regular_options, "all": all_options}
|
||||
|
||||
|
||||
@admin_router.post("/test")
|
||||
def test_llm_configuration(
|
||||
test_llm_request: TestLLMRequest,
|
||||
|
||||
@@ -82,6 +82,7 @@ class LLMProvider(BaseModel):
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
use_recommended_models: bool | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -127,6 +128,7 @@ class LLMProviderView(LLMProvider):
|
||||
is_public=llm_provider_model.is_public,
|
||||
groups=groups,
|
||||
deployment_name=llm_provider_model.deployment_name,
|
||||
use_recommended_models=llm_provider_model.use_recommended_models,
|
||||
model_configurations=list(
|
||||
ModelConfigurationView.from_model(
|
||||
model_configuration, llm_provider_model.provider
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# Integration tests for LLM model update curated tasks
|
||||
@@ -0,0 +1,601 @@
|
||||
"""
|
||||
Integration tests for the LLM Model Update Onyx Curated task.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks import (
|
||||
_delete_deprecated_not_visible_models,
|
||||
)
|
||||
from onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks import (
|
||||
_sync_model_configurations_with_curated_models,
|
||||
)
|
||||
from onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks import (
|
||||
check_for_llm_model_update_onyx_curated,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# Test constants
|
||||
TEST_TENANT_ID = "test-tenant-id"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def db_session() -> Session:
    """Yield a tenant-scoped database session for the duration of one test."""
    with get_session_with_current_tenant() as tenant_session:
        yield tenant_session
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def tenant_context():
    """Bind TEST_TENANT_ID to the tenant contextvar for the test's lifetime.

    The token is always reset afterwards so one test cannot leak its tenant
    into the next.
    """
    reset_token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
    try:
        yield
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(reset_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def admin_user() -> DATestUser:
    """Create a fresh admin user for the test (first created user is admin)."""
    return UserManager.create(name="admin_user")
|
||||
|
||||
|
||||
def _create_test_llm_provider(
    db_session: Session, provider_name: str, provider_type: str
) -> LLMProvider:
    """Persist and return a minimal, public, non-default LLM provider row."""
    provider = LLMProvider(
        name=provider_name,
        provider=provider_type,
        default_model_name="test-model",
        is_default_provider=False,
        is_public=True,
    )
    db_session.add(provider)
    db_session.commit()
    # Refresh so DB-generated fields (e.g. the primary key) are populated.
    db_session.refresh(provider)
    return provider
|
||||
|
||||
|
||||
def _create_test_model_configuration(
    db_session: Session,
    llm_provider: LLMProvider,
    model_name: str,
    is_visible: bool = True,
    recommended_default: bool = False,
    recommended_fast_default: bool = False,
    recommended_is_visible: bool = True,
) -> ModelConfiguration:
    """Persist and return a model-configuration row tied to *llm_provider*."""
    configuration = ModelConfiguration(
        llm_provider_id=llm_provider.id,
        name=model_name,
        is_visible=is_visible,
        recommended_default=recommended_default,
        recommended_fast_default=recommended_fast_default,
        recommended_is_visible=recommended_is_visible,
    )
    db_session.add(configuration)
    db_session.commit()
    # Refresh so DB-generated fields (e.g. the primary key) are populated.
    db_session.refresh(configuration)
    return configuration
|
||||
|
||||
|
||||
def _get_model_configurations_for_provider(
    db_session: Session, llm_provider: LLMProvider
) -> list[ModelConfiguration]:
    """Return every model-configuration row belonging to *llm_provider*."""
    provider_filter = ModelConfiguration.llm_provider_id == llm_provider.id
    query = db_session.query(ModelConfiguration).filter(provider_filter)
    return query.all()
|
||||
|
||||
|
||||
class TestLLMModelUpdateOnyxCurated:
|
||||
"""Test class for LLM model update curated tasks"""
|
||||
|
||||
def test_sync_adds_new_models_from_curated_list(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that new models from curated list are added to database"""
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
|
||||
# Verify no models exist initially
|
||||
initial_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
assert len(initial_models) == 0
|
||||
|
||||
# Mock curated_models data
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o-mini",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Verify models were added
|
||||
updated_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
assert len(updated_models) == 2
|
||||
|
||||
# Verify model details
|
||||
model_names = {model.name for model in updated_models}
|
||||
assert "gpt-4" in model_names
|
||||
assert "gpt-4o-mini" in model_names
|
||||
|
||||
# Verify recommendations were set correctly
|
||||
gpt4_model = next(m for m in updated_models if m.name == "gpt-4")
|
||||
assert gpt4_model.recommended_default is True
|
||||
assert gpt4_model.recommended_fast_default is False
|
||||
assert gpt4_model.recommended_is_visible is True
|
||||
|
||||
gpt4o_mini_model = next(
|
||||
m for m in updated_models if m.name == "gpt-4o-mini"
|
||||
)
|
||||
assert gpt4o_mini_model.recommended_default is False
|
||||
assert gpt4o_mini_model.recommended_fast_default is True
|
||||
assert gpt4o_mini_model.recommended_is_visible is True
|
||||
|
||||
def test_sync_updates_existing_models_with_new_recommendations(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that existing models are updated with new recommendations"""
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
|
||||
# Create existing model with old recommendations
|
||||
existing_model = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-4",
|
||||
recommended_default=False,
|
||||
recommended_fast_default=False,
|
||||
recommended_is_visible=False,
|
||||
)
|
||||
|
||||
# Mock curated_models data with updated recommendations
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Refresh the model from database
|
||||
db_session.refresh(existing_model)
|
||||
|
||||
# Verify recommendations were updated
|
||||
assert existing_model.recommended_default is True
|
||||
assert existing_model.recommended_fast_default is True
|
||||
assert existing_model.recommended_is_visible is True
|
||||
|
||||
def test_sync_deletes_deprecated_not_visible_models(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that deprecated models that are not visible are deleted"""
|
||||
llm_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
_ = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-3.5-turbo",
|
||||
is_visible=False,
|
||||
)
|
||||
_ = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-4",
|
||||
is_visible=True,
|
||||
)
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-3.5-turbo",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Verify deprecated model was deleted
|
||||
remaining_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
model_names = {model.name for model in remaining_models}
|
||||
assert "gpt-3.5-turbo" not in model_names
|
||||
assert "gpt-4" in model_names
|
||||
assert len(remaining_models) == 1
|
||||
|
||||
def test_sync_skips_providers_with_custom_models(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that providers with custom models are skipped"""
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
|
||||
# Create custom model not in curated list
|
||||
_ = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"custom-model-not-in-curated",
|
||||
is_visible=True,
|
||||
)
|
||||
|
||||
# Mock curated_models data without the custom model
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Verify custom model still exists (provider was skipped)
|
||||
remaining_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
assert len(remaining_models) == 1
|
||||
assert remaining_models[0].name == "custom-model-not-in-curated"
|
||||
|
||||
# Verify no new models were added
|
||||
model_names = {model.name for model in remaining_models}
|
||||
assert "gpt-4" not in model_names
|
||||
|
||||
def test_sync_skips_providers_not_in_curated_models(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that providers not in curated_models are skipped"""
|
||||
# Create test LLM provider for unsupported provider
|
||||
llm_provider = _create_test_llm_provider(
|
||||
db_session, "Test Custom", "custom_provider"
|
||||
)
|
||||
|
||||
# Create existing model
|
||||
_ = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"custom-model",
|
||||
is_visible=True,
|
||||
)
|
||||
|
||||
# Mock curated_models data without the custom provider
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Verify existing model is unchanged (provider was skipped)
|
||||
remaining_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
assert len(remaining_models) == 1
|
||||
assert remaining_models[0].name == "custom-model"
|
||||
|
||||
def test_sync_handles_multiple_providers_correctly(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that multiple providers are handled correctly"""
|
||||
# Create test LLM providers
|
||||
openai_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
anthropic_provider = _create_test_llm_provider(
|
||||
db_session, "Test Anthropic", "anthropic"
|
||||
)
|
||||
|
||||
# Mock curated_models data for both providers
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
}
|
||||
],
|
||||
"anthropic": [
|
||||
{
|
||||
"name": "claude-3-opus",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Verify models were added for both providers
|
||||
openai_models = _get_model_configurations_for_provider(
|
||||
db_session, openai_provider
|
||||
)
|
||||
anthropic_models = _get_model_configurations_for_provider(
|
||||
db_session, anthropic_provider
|
||||
)
|
||||
|
||||
assert len(openai_models) == 1
|
||||
assert openai_models[0].name == "gpt-4"
|
||||
assert openai_models[0].recommended_default is True
|
||||
|
||||
assert len(anthropic_models) == 1
|
||||
assert anthropic_models[0].name == "claude-3-opus"
|
||||
assert anthropic_models[0].recommended_fast_default is True
|
||||
|
||||
def test_sync_handles_exception_gracefully(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test that the task handles exceptions gracefully"""
|
||||
# Create test LLM provider
|
||||
_create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
|
||||
# Mock curated_models to raise an exception
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
side_effect=Exception("Test exception"),
|
||||
):
|
||||
# Run the task
|
||||
result = check_for_llm_model_update_onyx_curated(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task returned None (indicating failure)
|
||||
assert result is None
|
||||
|
||||
def test_helper_function_delete_deprecated_not_visible_models(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test the helper function for deleting deprecated models"""
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
|
||||
# Create test models
|
||||
_ = _create_test_model_configuration(
|
||||
db_session, llm_provider, "deprecated-model", is_visible=False
|
||||
)
|
||||
_ = _create_test_model_configuration(
|
||||
db_session, llm_provider, "deprecated-visible-model", is_visible=True
|
||||
)
|
||||
_ = _create_test_model_configuration(
|
||||
db_session, llm_provider, "active-model", is_visible=False
|
||||
)
|
||||
|
||||
# Mock curated_models data
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "deprecated-model",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": True,
|
||||
},
|
||||
{
|
||||
"name": "deprecated-visible-model",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": True,
|
||||
},
|
||||
{
|
||||
"name": "active-model",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": False,
|
||||
"deprecated": False,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Get all model configurations
|
||||
model_configurations = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
|
||||
# Call the helper function
|
||||
_delete_deprecated_not_visible_models(
|
||||
db_session, llm_provider, model_configurations
|
||||
)
|
||||
|
||||
# Verify only the deprecated and not visible model was deleted
|
||||
remaining_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
model_names = {model.name for model in remaining_models}
|
||||
|
||||
assert "deprecated-model" not in model_names # Should be deleted
|
||||
assert "deprecated-visible-model" in model_names # Should remain (visible)
|
||||
assert "active-model" in model_names # Should remain (not deprecated)
|
||||
assert len(remaining_models) == 2
|
||||
|
||||
def test_helper_function_sync_model_configurations(
|
||||
self, db_session: Session, tenant_context, admin_user: DATestUser
|
||||
) -> None:
|
||||
"""Test the helper function for syncing model configurations"""
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(db_session, "Test OpenAI", "openai")
|
||||
|
||||
# Create existing model with old recommendations
|
||||
existing_model = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-4",
|
||||
recommended_default=False,
|
||||
recommended_fast_default=False,
|
||||
recommended_is_visible=False,
|
||||
)
|
||||
|
||||
# Mock curated_models data
|
||||
mock_curated_models = {
|
||||
"openai": [
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"recommended_default_model": True,
|
||||
"recommended_fast_default_model": True,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"recommended_default_model": False,
|
||||
"recommended_fast_default_model": False,
|
||||
"recommended_is_visible": True,
|
||||
"deprecated": False,
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.llm_model_update_onyx_curated.tasks.curated_models",
|
||||
mock_curated_models,
|
||||
):
|
||||
# Get all model configurations
|
||||
model_configurations = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
|
||||
# Call the helper function
|
||||
_sync_model_configurations_with_curated_models(
|
||||
db_session, llm_provider, model_configurations
|
||||
)
|
||||
|
||||
# Verify existing model was updated and new model was added
|
||||
updated_models = _get_model_configurations_for_provider(
|
||||
db_session, llm_provider
|
||||
)
|
||||
assert len(updated_models) == 2
|
||||
|
||||
# Verify existing model was updated
|
||||
db_session.refresh(existing_model)
|
||||
assert existing_model.recommended_default is True
|
||||
assert existing_model.recommended_fast_default is True
|
||||
assert existing_model.recommended_is_visible is True
|
||||
|
||||
# Verify new model was added
|
||||
model_names = {model.name for model in updated_models}
|
||||
assert "gpt-4" in model_names
|
||||
assert "gpt-4o" in model_names
|
||||
@@ -0,0 +1 @@
|
||||
# Empty init file to make this directory a Python package
|
||||
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Integration tests for the Update Recommended Selected Models task.
|
||||
|
||||
These tests verify the functionality of the update_recommended_selected_models
|
||||
Celery task which updates LLM model visibility, default model, and fast default model
|
||||
settings based on recommended configurations.
|
||||
|
||||
To run these tests:
|
||||
cd backend
|
||||
pytest tests/integration/tests/update_recommended_selected_models/test_update_recommended_selected_models.py -v
|
||||
|
||||
Prerequisites:
|
||||
- PostgreSQL database running
|
||||
- Proper environment variables set for database connection
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.update_recommended_selected_models.tasks import (
|
||||
update_recommended_selected_models,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
# Test constants
|
||||
TEST_TENANT_ID = "test-tenant-id"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session() -> Session:
|
||||
"""Create a database session for testing"""
|
||||
with get_session_with_current_tenant() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tenant_context():
|
||||
"""Set up tenant context for testing"""
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _create_test_llm_provider(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
provider_type: str,
|
||||
use_recommended_models: bool = True,
|
||||
default_model_name: str = "default-model",
|
||||
fast_default_model_name: str = "fast-default-model",
|
||||
) -> LLMProvider:
|
||||
"""Helper to create a test LLM provider"""
|
||||
# Try to find an existing provider first to avoid constraint violations
|
||||
existing_provider = (
|
||||
db_session.query(LLMProvider).filter(LLMProvider.name == provider_name).first()
|
||||
)
|
||||
|
||||
if existing_provider:
|
||||
# Update existing provider
|
||||
existing_provider.use_recommended_models = use_recommended_models
|
||||
existing_provider.default_model_name = default_model_name
|
||||
existing_provider.fast_default_model_name = fast_default_model_name
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_provider)
|
||||
return existing_provider
|
||||
|
||||
# Create new provider
|
||||
llm_provider = LLMProvider(
|
||||
name=provider_name,
|
||||
provider=provider_type,
|
||||
default_model_name=default_model_name,
|
||||
fast_default_model_name=fast_default_model_name,
|
||||
is_default_provider=False,
|
||||
is_public=True,
|
||||
use_recommended_models=use_recommended_models,
|
||||
)
|
||||
|
||||
try:
|
||||
db_session.add(llm_provider)
|
||||
db_session.commit()
|
||||
db_session.refresh(llm_provider)
|
||||
return llm_provider
|
||||
except Exception:
|
||||
# If we get a constraint error, try to find and return the existing one
|
||||
db_session.rollback()
|
||||
existing_provider = (
|
||||
db_session.query(LLMProvider)
|
||||
.filter(LLMProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
return existing_provider
|
||||
raise
|
||||
|
||||
|
||||
def _create_test_model_configuration(
|
||||
db_session: Session,
|
||||
llm_provider: LLMProvider,
|
||||
model_name: str,
|
||||
is_visible: bool = True,
|
||||
recommended_default: bool = False,
|
||||
recommended_fast_default: bool = False,
|
||||
recommended_is_visible: bool | None = None,
|
||||
) -> ModelConfiguration:
|
||||
"""Helper to create a test model configuration"""
|
||||
# Try to find existing model configuration first
|
||||
existing_config = (
|
||||
db_session.query(ModelConfiguration)
|
||||
.filter(
|
||||
ModelConfiguration.llm_provider_id == llm_provider.id,
|
||||
ModelConfiguration.name == model_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_config:
|
||||
# Update existing configuration
|
||||
existing_config.is_visible = is_visible
|
||||
existing_config.recommended_default = recommended_default
|
||||
existing_config.recommended_fast_default = recommended_fast_default
|
||||
existing_config.recommended_is_visible = recommended_is_visible
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_config)
|
||||
return existing_config
|
||||
|
||||
# Create new configuration
|
||||
model_config = ModelConfiguration(
|
||||
llm_provider_id=llm_provider.id,
|
||||
name=model_name,
|
||||
is_visible=is_visible,
|
||||
recommended_default=recommended_default,
|
||||
recommended_fast_default=recommended_fast_default,
|
||||
recommended_is_visible=recommended_is_visible,
|
||||
)
|
||||
db_session.add(model_config)
|
||||
db_session.commit()
|
||||
db_session.refresh(model_config)
|
||||
return model_config
|
||||
|
||||
|
||||
class TestUpdateRecommendedSelectedModels:
|
||||
"""Test class for update recommended selected models task"""
|
||||
|
||||
def test_updates_visibility_based_on_recommended_is_visible(
|
||||
self, db_session: Session, tenant_context
|
||||
) -> None:
|
||||
"""Test that model visibility is updated based on recommended_is_visible"""
|
||||
# Create unique provider name for this test
|
||||
provider_name = f"Test Provider Visibility {uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(
|
||||
db_session, provider_name, "openai", use_recommended_models=True
|
||||
)
|
||||
|
||||
# Create model with is_visible=True but recommended_is_visible=False
|
||||
model_config = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-4",
|
||||
is_visible=True,
|
||||
recommended_is_visible=False,
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = update_recommended_selected_models(Mock(), tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Refresh the model from database
|
||||
db_session.refresh(model_config)
|
||||
|
||||
# Verify visibility was updated
|
||||
assert model_config.is_visible is False
|
||||
|
||||
def test_updates_default_model_name_based_on_recommended_default(
|
||||
self, db_session: Session, tenant_context
|
||||
) -> None:
|
||||
"""Test that provider default_model_name is updated based on recommended_default"""
|
||||
# Create unique provider name for this test
|
||||
provider_name = f"Test Provider Default {uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create test LLM provider
|
||||
llm_provider = _create_test_llm_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
"openai",
|
||||
use_recommended_models=True,
|
||||
default_model_name="old-default-model",
|
||||
)
|
||||
|
||||
# Create model with recommended_default=True
|
||||
_ = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-4",
|
||||
recommended_default=True,
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = update_recommended_selected_models(Mock(), tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Refresh the provider from database
|
||||
db_session.refresh(llm_provider)
|
||||
|
||||
# Verify default_model_name was updated
|
||||
assert llm_provider.default_model_name == "gpt-4"
|
||||
|
||||
def test_skips_providers_without_use_recommended_models(
|
||||
self, db_session: Session, tenant_context
|
||||
) -> None:
|
||||
"""Test that providers without use_recommended_models=True are skipped"""
|
||||
# Create unique provider name for this test
|
||||
provider_name = f"Test Provider Skip {uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create test LLM provider with use_recommended_models=False
|
||||
llm_provider = _create_test_llm_provider(
|
||||
db_session,
|
||||
provider_name,
|
||||
"openai",
|
||||
use_recommended_models=False,
|
||||
default_model_name="old-default-model",
|
||||
)
|
||||
|
||||
# Create model with recommended_default=True
|
||||
model_config = _create_test_model_configuration(
|
||||
db_session,
|
||||
llm_provider,
|
||||
"gpt-4",
|
||||
is_visible=True,
|
||||
recommended_default=True,
|
||||
recommended_is_visible=False,
|
||||
)
|
||||
|
||||
# Run the task
|
||||
result = update_recommended_selected_models(Mock(), tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# Verify task completed successfully
|
||||
assert result is True
|
||||
|
||||
# Refresh the provider and model from database
|
||||
db_session.refresh(llm_provider)
|
||||
db_session.refresh(model_config)
|
||||
|
||||
# Verify nothing was updated (provider was skipped)
|
||||
assert llm_provider.default_model_name == "old-default-model"
|
||||
assert model_config.is_visible is True
|
||||
|
||||
def test_handles_exception_gracefully(
|
||||
self, db_session: Session, tenant_context
|
||||
) -> None:
|
||||
"""Test that the task handles exceptions gracefully"""
|
||||
# Create unique provider name for this test
|
||||
provider_name = f"Test Provider Exception {uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create test LLM provider
|
||||
_create_test_llm_provider(
|
||||
db_session, provider_name, "openai", use_recommended_models=True
|
||||
)
|
||||
|
||||
# Mock get_session_with_current_tenant to raise an exception
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.update_recommended_selected_models.tasks.get_session_with_current_tenant"
|
||||
) as mock_get_session:
|
||||
mock_get_session.side_effect = Exception("Database connection failed")
|
||||
|
||||
# Run the task
|
||||
result = update_recommended_selected_models(
|
||||
Mock(), tenant_id=TEST_TENANT_ID
|
||||
)
|
||||
|
||||
# Verify task returned None (indicating failure)
|
||||
assert result is None
|
||||
96
backend/tests/unit/onyx/llm/test_llm_provider_options.py
Normal file
96
backend/tests/unit/onyx/llm/test_llm_provider_options.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from onyx.llm.llm_provider_options import curated_models
|
||||
|
||||
|
||||
class TestCuratedModelsFormat:
|
||||
"""Test the format and constraints of the curated_models data structure."""
|
||||
|
||||
def test_deprecated_models_have_false_flags(self) -> None:
|
||||
for provider_name, models in curated_models.items():
|
||||
for model in models:
|
||||
if model.get("deprecated", False):
|
||||
assert (
|
||||
model.get("recommended_default_model", False) is False
|
||||
), f"Deprecated model '{model['name']}' in provider '{provider_name}' has recommended_default_model=True"
|
||||
|
||||
assert (
|
||||
model.get("recommended_fast_default_model", False) is False
|
||||
), f"Deprecated model '{model['name']}' in provider '{provider_name}' has recommended_fast_default_model=True"
|
||||
|
||||
assert (
|
||||
model.get("recommended_is_visible", False) is False
|
||||
), f"Deprecated model '{model['name']}' in provider '{provider_name}' has recommended_is_visible=True"
|
||||
|
||||
def test_model_names_are_globally_unique(self) -> None:
|
||||
all_model_names = []
|
||||
|
||||
for _, models in curated_models.items():
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
assert (
|
||||
model_name not in all_model_names
|
||||
), f"Model name '{model_name}' appears in multiple providers."
|
||||
all_model_names.append(model_name)
|
||||
|
||||
def test_at_most_one_default_model_per_provider(self) -> None:
|
||||
for provider_name, models in curated_models.items():
|
||||
default_models = []
|
||||
fast_default_models = []
|
||||
|
||||
for model in models:
|
||||
if model.get("recommended_default_model", False):
|
||||
default_models.append(model["name"])
|
||||
|
||||
if model.get("recommended_fast_default_model", False):
|
||||
fast_default_models.append(model["name"])
|
||||
|
||||
assert (
|
||||
len(default_models) <= 1
|
||||
), f"Provider '{provider_name}' has multiple recommended_default_model set to True: {default_models}"
|
||||
|
||||
assert (
|
||||
len(fast_default_models) <= 1
|
||||
), f"Provider '{provider_name}' has multiple recommended_fast_default_model set to True: {fast_default_models}"
|
||||
|
||||
def test_required_fields_present(self) -> None:
|
||||
"""Test that all required fields are present in each model definition."""
|
||||
required_fields = [
|
||||
"name",
|
||||
"friendly_name",
|
||||
"recommended_default_model",
|
||||
"recommended_fast_default_model",
|
||||
"recommended_is_visible",
|
||||
"deprecated",
|
||||
]
|
||||
|
||||
for provider_name, models in curated_models.items():
|
||||
for model in models:
|
||||
for field in required_fields:
|
||||
assert (
|
||||
field in model
|
||||
), f"Model '{model.get('name', 'UNKNOWN')}' in provider '{provider_name}' is missing required field '{field}'"
|
||||
|
||||
def test_field_types_are_correct(self) -> None:
|
||||
"""Test that all fields have the correct types."""
|
||||
for provider_name, models in curated_models.items():
|
||||
for model in models:
|
||||
# Test string fields
|
||||
assert isinstance(
|
||||
model["name"], str
|
||||
), f"Model name must be a string in provider '{provider_name}'"
|
||||
assert isinstance(
|
||||
model["friendly_name"], str
|
||||
), f"Model friendly_name must be a string in provider '{provider_name}'"
|
||||
|
||||
# Test boolean fields
|
||||
assert isinstance(
|
||||
model["recommended_default_model"], bool
|
||||
), f"Model recommended_default_model must be a boolean in provider '{provider_name}'"
|
||||
assert isinstance(
|
||||
model["recommended_fast_default_model"], bool
|
||||
), f"Model recommended_fast_default_model must be a boolean in provider '{provider_name}'"
|
||||
assert isinstance(
|
||||
model["recommended_is_visible"], bool
|
||||
), f"Model recommended_is_visible must be a boolean in provider '{provider_name}'"
|
||||
assert isinstance(
|
||||
model["deprecated"], bool
|
||||
), f"Model deprecated must be a boolean in provider '{provider_name}'"
|
||||
@@ -147,9 +147,11 @@ function LLMProviderDisplay({
|
||||
export function ConfiguredLLMProviderDisplay({
|
||||
existingLlmProviders,
|
||||
llmProviderDescriptors,
|
||||
allLlmProviderDescriptors,
|
||||
}: {
|
||||
existingLlmProviders: LLMProviderView[];
|
||||
llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
|
||||
allLlmProviderDescriptors: WellKnownLLMProviderDescriptor[];
|
||||
}) {
|
||||
existingLlmProviders = existingLlmProviders.sort((a, b) => {
|
||||
if (a.is_default_provider && !b.is_default_provider) {
|
||||
@@ -168,6 +170,10 @@ export function ConfiguredLLMProviderDisplay({
|
||||
(llmProviderDescriptors) =>
|
||||
llmProviderDescriptors.name === provider.provider
|
||||
);
|
||||
const allDefaultProviderDescriptor = allLlmProviderDescriptors.find(
|
||||
(allLlmProviderDescriptors) =>
|
||||
allLlmProviderDescriptors.name === provider.provider
|
||||
);
|
||||
|
||||
return (
|
||||
<LLMProviderDisplay
|
||||
@@ -177,8 +183,8 @@ export function ConfiguredLLMProviderDisplay({
|
||||
// provider descriptor
|
||||
llmProviderDescriptor={
|
||||
isSubset(
|
||||
defaultProviderDesciptor
|
||||
? defaultProviderDesciptor.model_configurations.map(
|
||||
allDefaultProviderDescriptor
|
||||
? allDefaultProviderDescriptor.model_configurations.map(
|
||||
(model_configuration) => model_configuration.name
|
||||
)
|
||||
: [],
|
||||
|
||||
@@ -134,12 +134,20 @@ export function LLMConfiguration() {
|
||||
const { data: llmProviderDescriptors } = useSWR<
|
||||
WellKnownLLMProviderDescriptor[]
|
||||
>("/api/admin/llm/built-in/options", errorHandlingFetcher);
|
||||
const { data: llmOptionsWithAllModels } = useSWR<{
|
||||
regular: WellKnownLLMProviderDescriptor[];
|
||||
all: WellKnownLLMProviderDescriptor[];
|
||||
}>("/api/admin/llm/built-in/options-with-all-models", errorHandlingFetcher);
|
||||
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
if (!llmProviderDescriptors || !existingLlmProviders) {
|
||||
if (
|
||||
!llmProviderDescriptors ||
|
||||
!existingLlmProviders ||
|
||||
!llmOptionsWithAllModels
|
||||
) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
@@ -158,6 +166,7 @@ export function LLMConfiguration() {
|
||||
<ConfiguredLLMProviderDisplay
|
||||
existingLlmProviders={existingLlmProviders}
|
||||
llmProviderDescriptors={llmProviderDescriptors}
|
||||
allLlmProviderDescriptors={llmOptionsWithAllModels.all}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
|
||||
@@ -24,6 +24,7 @@ import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
|
||||
export function LLMProviderUpdateForm({
|
||||
llmProviderDescriptor,
|
||||
@@ -49,7 +50,23 @@ export function LLMProviderUpdateForm({
|
||||
const [isTesting, setIsTesting] = useState(false);
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
|
||||
// Determine which model configurations to use for options
|
||||
// When editing an existing provider, use the database configurations
|
||||
// When creating a new provider, use the hardcoded descriptor configurations
|
||||
const modelConfigurationsForOptions = existingLlmProvider
|
||||
? existingLlmProvider.model_configurations
|
||||
: llmProviderDescriptor.model_configurations;
|
||||
|
||||
// Check if there are any display models available
|
||||
const hasDisplayModels = modelConfigurationsForOptions.some(
|
||||
(modelConfiguration) => modelConfiguration.is_visible
|
||||
);
|
||||
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const [useRecommendedModel, setUseRecommendedModel] = useState(
|
||||
existingLlmProvider?.use_recommended_models ??
|
||||
(hasDisplayModels && llmProviderDescriptor.has_curated_models)
|
||||
);
|
||||
|
||||
// Define the initial values based on the provider's requirements
|
||||
const initialValues = {
|
||||
@@ -160,38 +177,51 @@ export function LLMProviderUpdateForm({
|
||||
const finalValues = {
|
||||
...rest,
|
||||
api_key_changed: values.api_key !== initialValues.api_key,
|
||||
model_configurations: llmProviderDescriptor.model_configurations.map(
|
||||
// If using recommended model, use the provider's defaults
|
||||
default_model_name: useRecommendedModel
|
||||
? (existingLlmProvider?.default_model_name ??
|
||||
(llmProviderDescriptor.default_model ||
|
||||
llmProviderDescriptor.model_configurations[0]?.name))
|
||||
: values.default_model_name,
|
||||
fast_default_model_name: useRecommendedModel
|
||||
? (existingLlmProvider?.fast_default_model_name ??
|
||||
(llmProviderDescriptor.default_fast_model ||
|
||||
llmProviderDescriptor.default_model ||
|
||||
llmProviderDescriptor.model_configurations[0]?.name))
|
||||
: values.fast_default_model_name,
|
||||
model_configurations: modelConfigurationsForOptions.map(
|
||||
(modelConfiguration): ModelConfigurationUpsertRequest => ({
|
||||
name: modelConfiguration.name,
|
||||
is_visible: visibleModels.includes(modelConfiguration.name),
|
||||
is_visible: useRecommendedModel
|
||||
? modelConfiguration.is_visible
|
||||
: visibleModels.includes(modelConfiguration.name),
|
||||
max_input_tokens: null,
|
||||
})
|
||||
),
|
||||
};
|
||||
|
||||
// test the configuration
|
||||
if (!isEqual(finalValues, initialValues)) {
|
||||
setIsTesting(true);
|
||||
// if (!isEqual(finalValues, initialValues)) {
|
||||
// setIsTesting(true);
|
||||
|
||||
const response = await fetch("/api/admin/llm/test", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
provider: llmProviderDescriptor.name,
|
||||
...finalValues,
|
||||
}),
|
||||
});
|
||||
setIsTesting(false);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
setTestError(errorMsg);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// const response = await fetch("/api/admin/llm/test", {
|
||||
// method: "POST",
|
||||
// headers: {
|
||||
// "Content-Type": "application/json",
|
||||
// },
|
||||
// body: JSON.stringify({
|
||||
// provider: llmProviderDescriptor.name,
|
||||
// ...finalValues,
|
||||
// }),
|
||||
// });
|
||||
// setIsTesting(false);
|
||||
|
||||
// if (!response.ok) {
|
||||
// const errorMsg = (await response.json()).detail;
|
||||
// setTestError(errorMsg);
|
||||
// return;
|
||||
// }
|
||||
// }
|
||||
console.log(useRecommendedModel);
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}${
|
||||
existingLlmProvider ? "" : "?is_creation=true"
|
||||
@@ -204,6 +234,7 @@ export function LLMProviderUpdateForm({
|
||||
body: JSON.stringify({
|
||||
provider: llmProviderDescriptor.name,
|
||||
...finalValues,
|
||||
use_recommended_models: useRecommendedModel,
|
||||
fast_default_model_name:
|
||||
finalValues.fast_default_model_name ||
|
||||
finalValues.default_model_name,
|
||||
@@ -344,67 +375,110 @@ export function LLMProviderUpdateForm({
|
||||
<>
|
||||
<Separator />
|
||||
|
||||
{llmProviderDescriptor.model_configurations.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
(modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
name: modelConfiguration.name,
|
||||
value: modelConfiguration.name,
|
||||
})
|
||||
{hasDisplayModels && llmProviderDescriptor.has_curated_models && (
|
||||
<div className="flex items-center space-x-2 mb-4">
|
||||
<Switch
|
||||
checked={useRecommendedModel}
|
||||
onCheckedChange={setUseRecommendedModel}
|
||||
/>
|
||||
<label className="text-sm font-medium">
|
||||
Use recommended models
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!useRecommendedModel && (
|
||||
<>
|
||||
{modelConfigurationsForOptions.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
options={modelConfigurationsForOptions.map(
|
||||
(modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
name: modelConfiguration.name,
|
||||
value: modelConfiguration.name,
|
||||
})
|
||||
)}
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
)}
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="default_model_name"
|
||||
subtext="The model to use by default for this provider unless otherwise specified."
|
||||
label="Default Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
)}
|
||||
|
||||
{llmProviderDescriptor.deployment_name_required && (
|
||||
<TextFormField
|
||||
name="deployment_name"
|
||||
label="Deployment Name"
|
||||
placeholder="Deployment Name"
|
||||
/>
|
||||
)}
|
||||
{llmProviderDescriptor.deployment_name_required && (
|
||||
<TextFormField
|
||||
name="deployment_name"
|
||||
label="Deployment Name"
|
||||
placeholder="Deployment Name"
|
||||
/>
|
||||
)}
|
||||
|
||||
{!llmProviderDescriptor.single_model_supported &&
|
||||
(llmProviderDescriptor.model_configurations.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
{!llmProviderDescriptor.single_model_supported &&
|
||||
(modelConfigurationsForOptions.length > 0 ? (
|
||||
<SelectorFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
(modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
name: modelConfiguration.name,
|
||||
value: modelConfiguration.name,
|
||||
})
|
||||
)}
|
||||
includeDefault
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
label="[Optional] Fast Model"
|
||||
options={modelConfigurationsForOptions.map(
|
||||
(modelConfiguration) => ({
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
name: modelConfiguration.name,
|
||||
value: modelConfiguration.name,
|
||||
})
|
||||
)}
|
||||
includeDefault
|
||||
maxHeight="max-h-56"
|
||||
/>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="fast_default_model_name"
|
||||
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
|
||||
for this provider. If \`Default\` is specified, will use
|
||||
the Default Model configured above.`}
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
))}
|
||||
label="[Optional] Fast Model"
|
||||
placeholder="E.g. gpt-4"
|
||||
/>
|
||||
))}
|
||||
|
||||
{modelConfigurationsForOptions.length > 0 && (
|
||||
<div className="w-full">
|
||||
<MultiSelectField
|
||||
selectedInitially={
|
||||
formikProps.values.selected_model_names ?? []
|
||||
}
|
||||
name="selected_model_names"
|
||||
label="Display Models"
|
||||
subtext="Select the models to make available to users. Unselected models will not be available."
|
||||
options={modelConfigurationsForOptions.map(
|
||||
(modelConfiguration) => ({
|
||||
value: modelConfiguration.name,
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
label: modelConfiguration.name,
|
||||
})
|
||||
)}
|
||||
onChange={(selected) =>
|
||||
formikProps.setFieldValue(
|
||||
"selected_model_names",
|
||||
selected
|
||||
)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
<>
|
||||
<Separator />
|
||||
@@ -414,32 +488,6 @@ export function LLMProviderUpdateForm({
|
||||
/>
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
{llmProviderDescriptor.model_configurations.length > 0 && (
|
||||
<div className="w-full">
|
||||
<MultiSelectField
|
||||
selectedInitially={
|
||||
formikProps.values.selected_model_names ?? []
|
||||
}
|
||||
name="selected_model_names"
|
||||
label="Display Models"
|
||||
subtext="Select the models to make available to users. Unselected models will not be available."
|
||||
options={llmProviderDescriptor.model_configurations.map(
|
||||
(modelConfiguration) => ({
|
||||
value: modelConfiguration.name,
|
||||
// don't clean up names here to give admins descriptive names / handle duplicates
|
||||
// like us.anthropic.claude-3-7-sonnet-20250219-v1:0 and anthropic.claude-3-7-sonnet-20250219-v1:0
|
||||
label: modelConfiguration.name,
|
||||
})
|
||||
)}
|
||||
onChange={(selected) =>
|
||||
formikProps.setFieldValue(
|
||||
"selected_model_names",
|
||||
selected
|
||||
)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<IsPublicGroupSelector
|
||||
formikProps={formikProps}
|
||||
objectName="LLM Provider"
|
||||
|
||||
@@ -36,6 +36,7 @@ export interface WellKnownLLMProviderDescriptor {
|
||||
default_fast_model: string | null;
|
||||
is_public: boolean;
|
||||
groups: number[];
|
||||
has_curated_models: boolean;
|
||||
}
|
||||
|
||||
export interface LLMModelDescriptor {
|
||||
@@ -59,6 +60,7 @@ export interface LLMProvider {
|
||||
default_vision_model: string | null;
|
||||
is_default_vision_provider: boolean | null;
|
||||
model_configurations: ModelConfiguration[];
|
||||
use_recommended_models: boolean | null;
|
||||
}
|
||||
|
||||
export interface LLMProviderView extends LLMProvider {
|
||||
|
||||
@@ -823,26 +823,3 @@ export function getDisplayNameForModel(modelName: string): string {
|
||||
|
||||
return MODEL_DISPLAY_NAMES[modelName] || modelName;
|
||||
}
|
||||
|
||||
export const defaultModelsByProvider: { [name: string]: string[] } = {
|
||||
openai: [
|
||||
"gpt-4",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"o3-mini",
|
||||
"o1-mini",
|
||||
"o1",
|
||||
"o4-mini",
|
||||
"o3",
|
||||
],
|
||||
bedrock: [
|
||||
"meta.llama3-1-70b-instruct-v1:0",
|
||||
"meta.llama3-1-8b-instruct-v1:0",
|
||||
"anthropic.claude-3-opus-20240229-v1:0",
|
||||
"mistral.mistral-large-2402-v1:0",
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0",
|
||||
],
|
||||
anthropic: ["claude-3-opus-20240229", "claude-3-5-sonnet-20241022"],
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user