Compare commits

...

70 Commits

Author SHA1 Message Date
Dane Urban
376a2aee52 . 2026-01-30 21:22:42 -08:00
Dane Urban
8cac462353 nit 2026-01-30 21:10:18 -08:00
Dane Urban
1b849f6df6 . 2026-01-30 19:31:00 -08:00
Dane Urban
d105b96c97 . 2026-01-30 19:02:52 -08:00
Dane Urban
7cd555653f . 2026-01-30 18:57:38 -08:00
Dane Urban
38fb3f96ec resolve mcs 2026-01-30 18:51:35 -08:00
Dane Urban
339a111a8f . 2026-01-30 18:19:03 -08:00
Dane Urban
89696a7f57 Merge branch 'main' into persona_port 2026-01-30 18:08:08 -08:00
Dane Urban
09b7e6fc9b fix revision id 2026-01-30 17:39:02 -08:00
Dane Urban
135238014f Merge branch 'main' into flow_mapping_table 2026-01-30 17:38:20 -08:00
Dane Urban
303e37bf53 migrate 2026-01-30 17:38:15 -08:00
Dane Urban
6a888e9900 nit 2026-01-30 17:01:22 -08:00
Dane Urban
e90a7767c6 nit 2026-01-30 15:35:31 -08:00
Dane Urban
108e913017 resolve mc 2026-01-30 14:27:37 -08:00
Dane Urban
1ded3af63c nit 2026-01-30 14:22:27 -08:00
Dane Urban
c53546c000 nit 2026-01-30 13:03:05 -08:00
Dane Urban
9afa12edda nit 2026-01-30 13:02:48 -08:00
Dane Urban
32046de962 nit 2026-01-30 13:01:36 -08:00
Dane Urban
bdf93704db . 2026-01-30 11:57:24 -08:00
Dane Urban
81b4dbb5b0 resolve mcs 2026-01-30 10:52:51 -08:00
Dane Urban
1124a220cd nit 2026-01-29 17:31:44 -08:00
Dane Urban
f4e14bfa06 . 2026-01-29 16:48:03 -08:00
Dane Urban
02d29a1a83 . 2026-01-29 16:02:55 -08:00
Dane Urban
709261aec9 nit 2026-01-29 14:12:42 -08:00
Dane Urban
ec9c494a15 nit 2026-01-29 13:58:28 -08:00
Dane Urban
f160692d19 fix bugs 2026-01-29 13:56:22 -08:00
Dane Urban
631dcf8222 . 2026-01-29 13:47:01 -08:00
Dane Urban
334ae3ab17 . 2026-01-29 13:19:18 -08:00
Dane Urban
6904ee5de0 . 2026-01-29 13:16:25 -08:00
Dane Urban
4d913d7c3d nit 2026-01-29 13:02:07 -08:00
Dane Urban
d7f67d92b8 nit 2026-01-29 10:07:24 -08:00
Dane Urban
bc52fcc4ce nit 2026-01-28 21:44:35 -08:00
Dane Urban
6ec04dc446 nit 2026-01-28 21:07:56 -08:00
Dane Urban
2b90453748 nit 2026-01-28 20:51:49 -08:00
Dane Urban
db04ade260 line removal 2026-01-28 20:24:37 -08:00
Dane Urban
da31f91631 nits 2026-01-28 20:18:30 -08:00
Dane Urban
dbaefae827 nits 2026-01-28 19:51:50 -08:00
Dane Urban
fd407e5657 . 2026-01-28 17:16:34 -08:00
Dane Urban
65a18a720e Merge branch 'persona_port' into config_change 2026-01-28 16:12:02 -08:00
Dane Urban
a1ed112972 wqMerge branch 'main' into persona_port 2026-01-28 16:11:42 -08:00
Dane Urban
40221968a7 nit 2026-01-28 16:07:01 -08:00
Dane Urban
ae325a8b01 nit 2026-01-28 13:39:30 -08:00
Dane Urban
9579dd4be4 resolve mc 2026-01-28 13:36:00 -08:00
Dane Urban
a461ca9710 nit 2026-01-28 13:27:20 -08:00
Dane Urban
767dba1096 nit 2026-01-28 12:40:32 -08:00
Dane Urban
44ed8632a5 Merge branch 'main' into persona_port 2026-01-28 12:35:49 -08:00
Dane Urban
8d12302378 fixes 2026-01-28 12:35:09 -08:00
Dane Urban
0b00d56e18 undo 2026-01-28 12:22:09 -08:00
Dane Urban
590d844ea7 update revises 2026-01-28 12:12:55 -08:00
Dane Urban
cfa0f6e4fa . 2026-01-28 12:07:38 -08:00
Dane Urban
81829025b7 Resolve mc 2026-01-28 12:05:33 -08:00
Dane Urban
a24270c13b . 2026-01-28 11:54:33 -08:00
Dane Urban
cd62834902 nit 2026-01-28 11:50:47 -08:00
Dane Urban
0e93bd2d6f . 2026-01-28 10:46:45 -08:00
Dane Urban
5fde172464 n 2026-01-28 09:57:32 -08:00
Dane Urban
0593107528 nit 2026-01-28 09:49:14 -08:00
Dane Urban
4e2b9e3a32 update downgrade 2026-01-28 09:15:56 -08:00
Dane Urban
6b9870375c Update interface 2026-01-28 09:10:56 -08:00
Dane Urban
62ee239a1a nit 2026-01-27 23:23:22 -08:00
Dane Urban
2932fe948d . 2026-01-27 20:14:00 -08:00
Dane Urban
07b2b4e0ff . 2026-01-27 20:10:54 -08:00
Dane Urban
6b1005240c undo 2026-01-27 20:03:36 -08:00
Dane Urban
cd54511605 . 2026-01-27 20:01:28 -08:00
Dane Urban
e556a13f46 . 2026-01-27 17:34:59 -08:00
Dane Urban
acd81c99d4 nit 2026-01-27 16:04:33 -08:00
Dane Urban
d2eccbb0e0 nit 2026-01-27 16:01:15 -08:00
Dane Urban
2383caaccd nits 2026-01-27 15:40:24 -08:00
Dane Urban
3559e7f2bf nit 2026-01-27 11:40:55 -08:00
Dane Urban
d10895336c nit 2026-01-27 10:13:53 -08:00
Dane Urban
94873c8a8e nit 2026-01-26 16:50:32 -08:00
46 changed files with 1299 additions and 590 deletions

View File

@@ -0,0 +1,107 @@
"""Generalise model config and differentiate via flow
Revision ID: 0c5fbcd15bdd
Revises: f220515df7b4
Create Date: 2026-01-26 14:43:16.932376
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "0c5fbcd15bdd"
down_revision = "f220515df7b4"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Backfill model_flow rows from the legacy per-provider default columns.

    Every non-image-generation model configuration is registered for the
    conversation (text) flow; image-capable models are additionally registered
    for the vision flow. The legacy ``is_default_provider`` /
    ``default_model_name`` (and vision equivalents) on ``llm_provider`` mark
    the per-flow defaults.
    """
    # NOTE(review): revision f220515df7b4 (this migration's direct ancestor)
    # creates the table as "model_flow" with a "model_flow_type" column -- not
    # "flow_mapping" / "flow_type" as previously referenced here. The column
    # is a non-native sa.Enum over ModelFlowType, which persists member NAMES
    # ("CONVERSATION", "VISION") by default -- TODO confirm no values_callable
    # overrides this on the ORM column.
    op.execute(
        """
        INSERT INTO model_flow (model_flow_type, is_default, model_configuration_id)
        SELECT
            'CONVERSATION' AS model_flow_type,
            COALESCE(
                (lp.is_default_provider IS TRUE AND lp.default_model_name = mc.name),
                FALSE
            ) AS is_default,
            mc.id AS model_configuration_id
        FROM model_configuration mc
        LEFT JOIN llm_provider lp
            ON lp.id = mc.llm_provider_id
        WHERE NOT EXISTS (
            SELECT 1 FROM image_generation_config igc
            WHERE igc.model_configuration_id = mc.id
        );
        """
    )

    # Add vision-capable models to the vision flow. The supports_image_input
    # filter mirrors insert_new_model_configuration__no_commit, which only
    # creates a VISION mapping for image-capable models.
    op.execute(
        """
        INSERT INTO model_flow (model_flow_type, is_default, model_configuration_id)
        SELECT
            'VISION' AS model_flow_type,
            COALESCE(
                (lp.is_default_vision_provider IS TRUE AND lp.default_vision_model = mc.name),
                FALSE
            ) AS is_default,
            mc.id AS model_configuration_id
        FROM model_configuration mc
        LEFT JOIN llm_provider lp
            ON lp.id = mc.llm_provider_id
        WHERE mc.supports_image_input IS TRUE;
        """
    )
def downgrade() -> None:
    """Restore the legacy per-provider default columns from model_flow.

    Repopulates ``llm_provider.is_default_vision_provider`` /
    ``default_vision_model`` and ``is_default_provider`` /
    ``default_model_name`` from the per-flow default rows, then backfills a
    ``default_model_name`` for providers without one (the column was NOT NULL
    originally).
    """
    # NOTE(review): table/column names follow revision f220515df7b4
    # ("model_flow" / "model_flow_type"), and enum values are the member
    # names persisted by a non-native sa.Enum -- confirm against the ORM.
    # Populate vision defaults from the vision-flow default row
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET
            is_default_vision_provider = TRUE,
            default_vision_model = mc.name
        FROM model_flow fm
        JOIN model_configuration mc ON mc.id = fm.model_configuration_id
        WHERE fm.model_flow_type = 'VISION'
          AND fm.is_default = TRUE
          AND mc.llm_provider_id = lp.id;
        """
    )
    # Populate text defaults from the conversation-flow default row
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET
            is_default_provider = TRUE,
            default_model_name = mc.name
        FROM model_flow fm
        JOIN model_configuration mc ON mc.id = fm.model_configuration_id
        WHERE fm.model_flow_type = 'CONVERSATION'
          AND fm.is_default = TRUE
          AND mc.llm_provider_id = lp.id;
        """
    )
    # Providers with conversation-flow mappings that aren't the default still
    # need a default_model_name (it was NOT NULL originally). Pick the first
    # visible model, or any model, for each such provider.
    op.execute(
        """
        UPDATE llm_provider AS lp
        SET default_model_name = (
            SELECT mc.name
            FROM model_configuration mc
            JOIN model_flow fm ON fm.model_configuration_id = mc.id
            WHERE mc.llm_provider_id = lp.id
              AND fm.model_flow_type = 'CONVERSATION'
            ORDER BY mc.is_visible DESC, mc.id ASC
            LIMIT 1
        )
        WHERE lp.default_model_name IS NULL;
        """
    )

View File

@@ -0,0 +1,97 @@
"""Persona uses model configuration id
Revision ID: 1b5455f99105
Revises: e7f8a9b0c1d2
Create Date: 2026-01-27 19:11:34.510574
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "1b5455f99105"
down_revision = "e7f8a9b0c1d2"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Port persona LLM overrides onto ``default_model_configuration_id``.

    We need to migrate from llm_model_provider_override and
    llm_model_version_override to default_model_configuration_id.

    Migration Strategy:
      1. Port over where provider_override + model_version are both present
      2. Port over where only provider is present (Provider default)
      3. Where only the model is provided, or neither provider nor model are
         provided, we go for the global default
    """
    conn = op.get_bind()

    # Strategy 1: Both provider_override and model_version are present.
    # Match against model_configuration using provider name and model name.
    # NOTE(review): personas whose provider/model pair matches no existing
    # configuration are left with a NULL default_model_configuration_id --
    # presumably they fall back to the global default at runtime; verify.
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET default_model_configuration_id = mc.id
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.llm_model_provider_override IS NOT NULL
            AND persona.llm_model_version_override IS NOT NULL
            AND lp.name = persona.llm_model_provider_override
            AND mc.name = persona.llm_model_version_override
            """
        )
    )

    # Strategy 2: Only provider is present (use the provider's default model,
    # i.e. the configuration whose name equals lp.default_model_name).
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET default_model_configuration_id = mc.id
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.llm_model_provider_override IS NOT NULL
            AND persona.llm_model_version_override IS NULL
            AND persona.default_model_configuration_id IS NULL
            AND lp.name = persona.llm_model_provider_override
            AND mc.name = lp.default_model_name
            """
        )
    )

    # Strategy 3: Only the model is set -- fall back to the global default
    # (the default provider's default model).
    # NOTE(review): this assigns the GLOBAL default even though the persona
    # named a specific model; llm_model_version_override itself is never
    # matched against model_configuration here -- confirm this is intended.
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET default_model_configuration_id = mc.id
            FROM model_configuration mc
            JOIN llm_provider lp ON mc.llm_provider_id = lp.id
            WHERE persona.llm_model_provider_override IS NULL
            AND persona.llm_model_version_override IS NOT NULL
            AND persona.default_model_configuration_id IS NULL
            AND lp.is_default_provider = true
            AND mc.name = lp.default_model_name
            """
        )
    )
def downgrade() -> None:
    """Copy each persona's default_model_configuration_id back into the
    legacy llm_model_provider_override / llm_model_version_override columns.
    """
    # Resolve the configuration row to recover both the provider name and
    # the model name it pointed at.
    restore_overrides = sa.text(
        """
        UPDATE persona
        SET llm_model_provider_override = lp.name,
            llm_model_version_override = mc.name
        FROM model_configuration mc
        JOIN llm_provider lp ON mc.llm_provider_id = lp.id
        WHERE persona.default_model_configuration_id IS NOT NULL
        AND persona.default_model_configuration_id = mc.id
        """
    )
    op.get_bind().execute(restore_overrides)

View File

@@ -0,0 +1,57 @@
"""Add flow mapping table
Revision ID: f220515df7b4
Revises: be87a654d5af
Create Date: 2026-01-30 12:21:24.955922
"""
from onyx.db.enums import ModelFlowType
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f220515df7b4"
down_revision = "be87a654d5af"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the model_flow table mapping model configurations to flow types.

    Each row registers one model configuration for one flow type; a partial
    unique index guarantees at most one default per flow type.
    """
    # Enum values are inlined rather than imported from application code:
    # a migration must stay reproducible even if onyx.db.enums.ModelFlowType
    # is changed or removed later. SQLAlchemy's Enum type persists member
    # NAMES by default, so the inline strings mirror the member names of
    # ModelFlowType (CONVERSATION / VISION / EMBEDDINGS).
    op.create_table(
        "model_flow",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "model_flow_type",
            sa.Enum(
                "CONVERSATION",
                "VISION",
                "EMBEDDINGS",
                name="modelflowtype",
                native_enum=False,
            ),
            nullable=False,
        ),
        sa.Column(
            "is_default", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column("model_configuration_id", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE"
        ),
        # A model may appear at most once per flow type
        sa.UniqueConstraint(
            "model_flow_type",
            "model_configuration_id",
            name="uq_model_config_per_flow_type",
        ),
    )
    # Partial unique index so that there is at most one default for each flow type
    op.create_index(
        "ix_one_default_per_model_flow",
        "model_flow",
        ["model_flow_type"],
        unique=True,
        postgresql_where=sa.text("is_default IS TRUE"),
    )
def downgrade() -> None:
    """Revert the upgrade by removing the model_flow table."""
    # Drop the model_flow table (its partial unique index is dropped
    # automatically along with the table)
    op.drop_table("model_flow")

View File

@@ -20,7 +20,7 @@ from ee.onyx.server.enterprise_settings.store import (
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
@@ -123,9 +123,12 @@ def _seed_llms(
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
if len(seeded_providers[0].model_configurations) > 0:
update_default_text_provider(
provider_id=seeded_providers[0].id,
model=seeded_providers[0].model_configurations[0].name,
db_session=db_session,
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
@@ -144,8 +147,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
default_model_configuration_id=persona.default_model_configuration_id,
starter_messages=persona.starter_messages,
is_public=persona.is_public,
db_session=db_session,

View File

@@ -33,7 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
@@ -300,12 +300,14 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
nonlocal has_set_default_provider
try:
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
update_default_text_provider(
provider_id=provider.id, model=default_model, db_session=db_session
)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -330,7 +332,7 @@ def configure_default_api_keys(db_session: Session) -> None:
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
_upsert(openai_provider, default_model_name)
# Create default image generation config using the OpenAI API key
try:
@@ -366,7 +368,7 @@ def configure_default_api_keys(db_session: Session) -> None:
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
_upsert(anthropic_provider, default_model_name)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -398,7 +400,7 @@ def configure_default_api_keys(db_session: Session) -> None:
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
_upsert(vertexai_provider, default_model_name)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -435,7 +437,7 @@ def configure_default_api_keys(db_session: Session) -> None:
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
_upsert(openrouter_provider, default_model_name)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -293,8 +293,7 @@ def create_temporary_persona(
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
default_model_configuration_id=persona_config.default_model_configuration_id,
)
if persona_config.prompts:

View File

@@ -118,8 +118,7 @@ class PersonaOverrideConfig(BaseModel):
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
default_model_configuration_id: int | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
# Note: prompt_ids removed - prompts are now embedded in personas

View File

@@ -269,3 +269,9 @@ class HierarchyNodeType(str, PyEnum):
# Slack
CHANNEL = "channel"
class ModelFlowType(str, PyEnum):
    # Usage "flows" a model configuration can serve. Referenced by the
    # model_flow table to register which models are usable -- and which one
    # is the default -- for each flow.
    CONVERSATION = "conversation"  # standard text/chat generation
    VISION = "vision"  # models that accept image input
    EMBEDDINGS = "embeddings"  # embedding generation

View File

@@ -1,9 +1,13 @@
from collections.abc import Callable
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.enums import ModelFlowType
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from onyx.db.models import DocumentSet
from onyx.db.models import ImageGenerationConfig
@@ -11,6 +15,7 @@ from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import ModelFlow
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import Tool as ToolModel
@@ -20,8 +25,10 @@ from onyx.llm.utils import model_supports_image_input
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationView
from shared_configs.enums import EmbeddingProvider
@@ -163,15 +170,16 @@ def validate_persona_ids_exist(
def get_personas_using_provider(
db_session: Session, provider_name: str
db_session: Session, provider: LLMProviderModel
) -> list[Persona]:
"""Get all non-deleted personas that use a specific LLM provider."""
return list(
db_session.scalars(
select(Persona).where(
Persona.llm_model_provider_override == provider_name,
Persona.deleted == False, # noqa: E712
select(Persona)
.join(
ModelConfiguration,
Persona.default_model_configuration_id == ModelConfiguration.id,
)
.where(ModelConfiguration.llm_provider_id == provider.id)
).all()
)
@@ -233,9 +241,6 @@ def upsert_llm_provider(
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -263,17 +268,14 @@ def upsert_llm_provider(
model_provider=llm_provider_upsert_request.provider,
)
db_session.execute(
insert(ModelConfiguration)
.values(
llm_provider_id=existing_llm_provider.id,
name=model_configuration.name,
is_visible=model_configuration.is_visible,
max_input_tokens=max_input_tokens,
supports_image_input=model_configuration.supports_image_input,
display_name=model_configuration.display_name,
)
.on_conflict_do_nothing()
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=existing_llm_provider.id,
model_name=model_configuration.name,
is_visible=model_configuration.is_visible,
max_input_tokens=max_input_tokens,
supports_image_input=model_configuration.supports_image_input or False,
display_name=model_configuration.display_name,
)
# Make sure the relationship table stays up to date
@@ -331,23 +333,19 @@ def sync_model_configurations(
model_name = model["name"]
if model_name not in existing_names:
# Insert new model with is_visible=False (user must explicitly enable)
db_session.execute(
insert(ModelConfiguration)
.values(
llm_provider_id=provider.id,
name=model_name,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
supports_image_input=model.get("supports_image_input", False),
display_name=model.get("display_name"),
)
.on_conflict_do_nothing()
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
model_name=model_name,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
supports_image_input=model.get("supports_image_input", False),
display_name=model.get("display_name"),
)
new_count += 1
if new_count > 0:
db_session.commit()
return new_count
@@ -371,38 +369,85 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
)
def fetch_existing_llm_providers(
def fetch_existing_llm_providers_supporting_flows(
db_session: Session,
flows: list[ModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
"""Fetch all LLM providers with optional filtering.
"""Fetch LLM providers that have at least one model supporting one of the specified flows.
Args:
db_session: Database session
flows: List of flow types to filter by (providers must have at least one model
with a ModelFlow matching any of these flow types)
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
exclude_image_generation_providers: If True, only match on flows.
If False, include image generation providers in addition to
flow-filtered providers.
Returns:
List of LLM providers matching the criteria
"""
stmt = select(LLMProviderModel).options(
# Subquery to find provider IDs that have at least one model supporting any of the flows
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(ModelFlow)
.where(ModelFlow.model_flow_type.in_(flows))
.distinct()
)
if exclude_image_generation_providers:
# Only include providers with matching flows
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
# Include providers with matching flows OR image generation providers
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
if exclude_image_generation_providers:
# Get LLM provider IDs used by ImageGenerationConfig
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = stmt.where(LLMProviderModel.id.not_in(image_gen_provider_ids))
providers = list(db_session.scalars(stmt).all())
if only_public:
return [provider for provider in providers if provider.is_public]
return providers
def fetch_existing_model_configs_for_flow(
    db_session: Session,
    flows: list[ModelFlowType],
) -> list[ModelConfiguration]:
    """Return every model configuration registered for any of the given flows.

    Args:
        db_session: Database session
        flows: Flow types to match against ``ModelFlow.model_flow_type``

    Returns:
        Model configurations whose flow mapping matches one of ``flows``,
        with their ``llm_provider`` relationship eagerly loaded.
    """
    query = (
        select(ModelConfiguration)
        .join(ModelFlow)
        .where(ModelFlow.model_flow_type.in_(flows))
        .options(selectinload(ModelConfiguration.llm_provider))
    )
    matching_configs = db_session.scalars(query).all()
    return list(matching_configs)
def fetch_existing_llm_provider(
name: str, db_session: Session
) -> LLMProviderModel | None:
@@ -419,6 +464,12 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
    id: int, db_session: Session
) -> LLMProviderModel | None:
    """Look up a single LLM provider row by primary key; None if absent."""
    lookup = select(LLMProviderModel).where(LLMProviderModel.id == id)
    return db_session.scalar(lookup)
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -429,26 +480,23 @@ def fetch_embedding_provider(
)
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.is_default_provider == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
def fetch_default_model(
db_session: Session, flow_type: ModelFlowType
) -> DefaultModel | None:
flow_mapping = db_session.scalar(
select(ModelFlow).where(
ModelFlow.model_flow_type == flow_type,
ModelFlow.is_default == True, # noqa: E712
)
)
if not provider_model:
if not flow_mapping:
return None
return LLMProviderView.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.is_default_vision_provider == True) # noqa: E712
.options(selectinload(LLMProviderModel.model_configurations))
return DefaultModel(
provider_id=flow_mapping.model_configuration.llm_provider.id,
model_name=flow_mapping.model_configuration.name,
)
if not provider_model:
return None
return LLMProviderView.from_model(provider_model)
def fetch_llm_provider_view(
@@ -462,6 +510,48 @@ def fetch_llm_provider_view(
return LLMProviderView.from_model(provider_model)
def fetch_llm_provider_view_from_id(
    db_session: Session, provider_id: int
) -> LLMProviderView | None:
    """Fetch a provider by id and wrap it in an LLMProviderView; None if missing."""
    provider = fetch_existing_llm_provider_by_id(
        id=provider_id, db_session=db_session
    )
    if provider is None:
        return None
    return LLMProviderView.from_model(provider)
def fetch_model_configuration_view(
    db_session: Session, model_configuration_id: int
) -> ModelConfigurationView | None:
    """Load a model configuration by id and convert it to its view model.

    Returns None when no configuration with the given id exists. The
    provider string passed to the view comes from the configuration's
    ``llm_provider`` relationship.
    """
    config = db_session.scalar(
        select(ModelConfiguration).where(
            ModelConfiguration.id == model_configuration_id
        )
    )
    if config is None:
        return None
    return ModelConfigurationView.from_model(
        model_configuration_model=config,
        provider=config.llm_provider.provider,
    )
def fetch_llm_provider_view_from_model_id(
    db_session: Session, model_configuration_id: int
) -> LLMProviderView | None:
    """Resolve a model configuration id to a view of its owning provider.

    Returns None when the configuration does not exist.
    """
    config = db_session.scalar(
        select(ModelConfiguration).where(
            ModelConfiguration.id == model_configuration_id
        )
    )
    if config is None:
        return None
    owning_provider = config.llm_provider
    return LLMProviderView.from_model(owning_provider)
def remove_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> None:
@@ -479,16 +569,18 @@ def remove_embedding_provider(
db_session.commit()
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
def remove_llm_provider(
db_session: Session, provider_id: int, commit: bool = True
) -> None:
provider = db_session.get(LLMProviderModel, provider_id)
if not provider:
raise ValueError("LLM Provider not found")
# Clear the provider override from any personas using it
# This causes them to fall back to the default provider
personas_using_provider = get_personas_using_provider(db_session, provider.name)
personas_using_provider = get_personas_using_provider(db_session, provider)
for persona in personas_using_provider:
persona.llm_model_provider_override = None
persona.default_model_configuration_id = None
db_session.execute(
delete(LLMProvider__UserGroup).where(
@@ -499,88 +591,86 @@ def remove_llm_provider(db_session: Session, provider_id: int) -> None:
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
db_session.commit()
def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> None:
"""Remove LLM provider."""
provider = db_session.get(LLMProviderModel, provider_id)
if not provider:
raise ValueError("LLM Provider not found")
# Clear the provider override from any personas using it
# This causes them to fall back to the default provider
personas_using_provider = get_personas_using_provider(db_session, provider.name)
for persona in personas_using_provider:
persona.llm_model_provider_override = None
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
db_session.flush()
def update_default_provider(provider_id: int, db_session: Session) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
existing_default = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
)
)
if existing_default:
existing_default.is_default_provider = None
# required to ensure that the below does not cause a unique constraint violation
if commit:
db_session.commit()
else:
db_session.flush()
new_default.is_default_provider = True
def _update_default_provider(
    provider_id: int,
    model: str,
    db_session: Session,
    flow_type: ModelFlowType,
    validation_func: (
        Callable[[ModelConfiguration], str] | None
    ) = None,  # Returns error message
) -> None:
    """Make ``model`` (under ``provider_id``) the default for ``flow_type``.

    Fetches the model configuration together with its flow mapping in a
    single joined query, optionally validates it, clears any previous default
    for the flow, marks the new mapping as default, forces the model visible,
    and commits.

    Raises:
        ValueError: if the model is not registered for the flow under the
            provider, or if ``validation_func`` returns an error message.
    """
    # Single query with join to get both model_config and its flow mapping
    result = db_session.execute(
        select(ModelConfiguration, ModelFlow)
        .join(ModelFlow, ModelFlow.model_configuration_id == ModelConfiguration.id)
        .where(
            ModelConfiguration.llm_provider_id == provider_id,
            ModelConfiguration.name == model,
            ModelFlow.model_flow_type == flow_type,
        )
    ).first()
    if not result:
        # Report the actual flow type instead of hard-coding "TEXT" -- this
        # helper also serves the vision (and any future) flows.
        raise ValueError(
            f"Model '{model}' is not a valid {flow_type.value} model"
            f" for provider '{provider_id}'"
        )
    model_config, new_default = result
    if validation_func and (error_message := validation_func(model_config)):
        raise ValueError(error_message)

    # Clear the existing default and set the new one in a single transaction
    db_session.execute(
        update(ModelFlow)
        .where(
            ModelFlow.model_flow_type == flow_type,
            ModelFlow.is_default == True,  # noqa: E712
        )
        .values(is_default=False)
    )
    new_default.is_default = True
    # A default model must be selectable, so force it visible
    model_config.is_visible = True
    db_session.commit()
def update_default_text_provider(
    provider_id: int, model: str, db_session: Session
) -> None:
    """Set the given provider/model pair as the default CONVERSATION model."""
    _update_default_provider(
        provider_id,
        model,
        db_session,
        flow_type=ModelFlowType.CONVERSATION,
    )
def update_default_vision_provider(
provider_id: int, vision_model: str | None, db_session: Session
provider_id: int, vision_model: str, db_session: Session
) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
def validate_vision_model(model_config: ModelConfiguration) -> str:
if not model_supports_image_input(
vision_model, model_config.llm_provider.provider
):
return f"Model '{vision_model}' for provider '{model_config.llm_provider.provider}' does not support image input"
return ""
_update_default_provider(
provider_id=provider_id,
model=vision_model,
db_session=db_session,
flow_type=ModelFlowType.VISION,
validation_func=validate_vision_model,
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
# Validate that the specified vision model supports image input
model_to_validate = vision_model or new_default.default_model_name
if model_to_validate:
if not model_supports_image_input(model_to_validate, new_default.provider):
raise ValueError(
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
)
else:
raise ValueError(
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
)
existing_default = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if existing_default:
existing_default.is_default_vision_provider = None
# required to ensure that the below does not cause a unique constraint violation
db_session.flush()
new_default.is_default_vision_provider = True
new_default.default_vision_model = vision_model
db_session.commit()
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
@@ -671,10 +761,84 @@ def sync_auto_mode_models(
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
# Check if this provider is the set default provider
default_model = fetch_default_model(db_session, ModelFlowType.CONVERSATION)
recommended = llm_recommendations.get_default_model(provider.provider)
if (
default_model
and default_model.provider_id == provider.id
and recommended
and recommended.name != default_model.model_name
):
update_default_text_provider(provider.id, recommended.name, db_session)
changes += 1
db_session.commit()
return changes
def create_new_flow_mapping__no_commit(
    db_session: Session,
    model_configuration_id: int,
    flow_type: ModelFlowType,
    is_default: bool = False,
) -> ModelFlow:
    """Create a new flow mapping without committing.

    Uses flush() to get the ID while keeping the transaction open,
    allowing callers like upsert_llm_provider to maintain atomicity.

    Args:
        db_session: Active session; the caller owns the final commit/rollback.
        model_configuration_id: id of the ModelConfiguration being mapped.
        flow_type: the flow (e.g. CONVERSATION / VISION) the model serves.
        is_default: whether this model is the default for the flow type. A
            partial unique index (ix_one_default_per_model_flow) enforces at
            most one default row per flow type.

    Returns:
        The flushed (still uncommitted) ModelFlow row with its id populated.
    """
    # NOTE(fix): the ORM column on ModelFlow is named `model_flow_type`;
    # passing `flow_type=` to the declarative constructor would raise a
    # TypeError for an unknown keyword argument.
    flow_mapping = ModelFlow(
        model_configuration_id=model_configuration_id,
        model_flow_type=flow_type,
        is_default=is_default,
    )
    db_session.add(flow_mapping)
    # flush assigns the primary key without ending the transaction
    db_session.flush()
    db_session.refresh(flow_mapping)
    return flow_mapping
def insert_new_model_configuration__no_commit(
    db_session: Session,
    llm_provider_id: int,
    model_name: str,
    is_visible: bool,
    max_input_tokens: int | None,
    supports_image_input: bool,
    display_name: str | None,
) -> int | None:
    """Insert a ModelConfiguration and its flow mappings without committing.

    Returns the new configuration's id, or None when a matching row already
    exists (the insert is a no-op on conflict). The caller owns the commit.
    """
    insert_stmt = (
        insert(ModelConfiguration)
        .values(
            llm_provider_id=llm_provider_id,
            name=model_name,
            is_visible=is_visible,
            max_input_tokens=max_input_tokens,
            display_name=display_name,
        )
        .on_conflict_do_nothing()
        .returning(ModelConfiguration.id)
    )
    new_config_id = db_session.execute(insert_stmt).scalar_one_or_none()
    if not new_config_id:
        # Conflict: a configuration for this provider/model already exists.
        return None

    # Every model participates in the conversation flow; vision is opt-in.
    flow_types = [ModelFlowType.CONVERSATION]
    if supports_image_input:
        flow_types.append(ModelFlowType.VISION)

    for flow_type in flow_types:
        create_new_flow_mapping__no_commit(
            db_session=db_session,
            model_configuration_id=new_config_id,
            flow_type=flow_type,
            is_default=False,
        )

    return new_config_id

View File

@@ -7,6 +7,7 @@ from uuid import uuid4
from pydantic import BaseModel
from sqlalchemy.orm import validates
from typing_extensions import TypedDict # noreorder
from uuid import UUID
from pydantic import ValidationError
@@ -70,6 +71,7 @@ from onyx.db.enums import (
MCPAuthenticationPerformer,
MCPTransport,
MCPServerStatus,
ModelFlowType,
ThemePreference,
SwitchoverType,
)
@@ -2613,14 +2615,16 @@ class LLMProvider(Base):
custom_config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Deprecated: use ModelFlow.is_default
default_model_name: Mapped[str] = mapped_column(String)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# should only be set for a single provider
# Deprecated: Use ModelFlows instead
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# Auto mode: models, visibility, and defaults are managed by GitHub config
@@ -2670,6 +2674,7 @@ class ModelConfiguration(Base):
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Deprecated: Use ModelFlows instead
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Human-readable display name for the model.
@@ -2682,6 +2687,51 @@ class ModelConfiguration(Base):
back_populates="model_configurations",
)
model_flows: Mapped[list["ModelFlow"]] = relationship(
"ModelFlow",
back_populates="model_configuration",
cascade="all, delete-orphan",
passive_deletes=True,
)
@property
def model_flow_types(self) -> list[ModelFlowType]:
return [flow.model_flow_type for flow in self.model_flows]
class ModelFlow(Base):
    """Association of a ModelConfiguration with a flow it can serve.

    A "flow" is a usage category (e.g. CONVERSATION, VISION). Each
    (flow_type, model_configuration) pair may appear at most once, and at
    most one row per flow type may be flagged as the default.
    """

    __tablename__ = "model_flow"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Which flow this mapping enables for the model configuration.
    model_flow_type: Mapped[ModelFlowType] = mapped_column(
        Enum(ModelFlowType), nullable=False
    )
    # Owning model configuration; rows are removed when the config is deleted
    # (ondelete CASCADE, paired with the relationship's delete-orphan cascade).
    model_configuration_id: Mapped[int] = mapped_column(
        ForeignKey("model_configuration.id", ondelete="CASCADE"),
        nullable=False,
    )
    # True when this model is the default choice for the flow type.
    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    model_configuration: Mapped["ModelConfiguration"] = relationship(
        "ModelConfiguration",
        back_populates="model_flows",
    )

    __table_args__ = (
        # A model configuration may map to a given flow type only once.
        UniqueConstraint(
            "model_flow_type",
            "model_configuration_id",
            name="uq_model_config_per_flow_type",
        ),
        # Partial unique index: at most one default row per flow type.
        Index(
            "ix_one_default_per_model_flow",
            "model_flow_type",
            unique=True,
            postgresql_where=(is_default == True),  # noqa: E712
        ),
    )
class ImageGenerationConfig(Base):
__tablename__ = "image_generation_config"
@@ -3013,6 +3063,12 @@ class Persona(Base):
# Allows the persona to specify a specific default LLM model
# NOTE: only is applied on the actual response generation - is not used for things like
# auto-detected time filters, relevance filters, etc.
default_model_configuration_id: Mapped[int | None] = mapped_column(
Integer,
ForeignKey("model_configuration.id", ondelete="SET NULL"),
nullable=True,
)
# llm_model_provider_override and llm_model_version_override are deprecated and will be removed in a future release
llm_model_provider_override: Mapped[str | None] = mapped_column(
String, nullable=True
)

View File

@@ -284,8 +284,7 @@ def create_update_persona(
tool_ids=create_persona_request.tool_ids,
is_public=create_persona_request.is_public,
recency_bias=create_persona_request.recency_bias,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
default_model_configuration_id=create_persona_request.default_model_configuration_id,
starter_messages=create_persona_request.starter_messages,
system_prompt=create_persona_request.system_prompt,
task_prompt=create_persona_request.task_prompt,
@@ -817,8 +816,7 @@ def upsert_persona(
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
default_model_configuration_id: int | None,
starter_messages: list[StarterMessage] | None,
# Embedded prompt fields
system_prompt: str | None,
@@ -958,8 +956,7 @@ def upsert_persona(
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.default_model_configuration_id = default_model_configuration_id
existing_persona.starter_messages = starter_messages
existing_persona.deleted = False # Un-delete if previously deleted
existing_persona.is_public = is_public
@@ -1032,8 +1029,7 @@ def upsert_persona(
datetime_aware=(datetime_aware if datetime_aware is not None else True),
replace_base_system_prompt=replace_base_system_prompt,
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
default_model_configuration_id=default_model_configuration_id,
starter_messages=starter_messages,
tools=tools or [],
uploaded_image_id=uploaded_image_id,

View File

@@ -79,8 +79,7 @@ def create_slack_channel_persona(
recency_bias=RecencyBiasSetting.AUTO,
tool_ids=[search_tool.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,
default_model_configuration_id=None,
starter_messages=None,
is_public=True,
is_default_persona=False,

View File

@@ -5,22 +5,24 @@ from sqlalchemy.orm import Session
from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_default_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_model_configs_for_flow
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import fetch_llm_provider_view_from_id
from onyx.db.llm import fetch_llm_provider_view_from_model_id
from onyx.db.llm import fetch_model_configuration_view
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMProvider
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.multi_llm import LitellmLLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.manage.llm.models import LLMProviderView
@@ -53,54 +55,6 @@ def _build_provider_extra_headers(
return {}
def get_llm_config_for_persona(
    persona: Persona,
    db_session: Session,
    llm_override: LLMOverride | None = None,
) -> LLMConfig:
    """Get LLM config from persona without access checks.

    This function assumes access to the persona has already been verified.
    Use this when you need the LLM config but don't need to create the full
    LLM object.
    """
    # Unpack any explicit override values up front.
    provider_override: str | None = None
    version_override: str | None = None
    temperature_override: float | None = None
    if llm_override:
        provider_override = llm_override.model_provider
        version_override = llm_override.model_version
        temperature_override = llm_override.temperature

    provider_name = provider_override or persona.llm_model_provider_override
    if provider_name:
        llm_provider = fetch_llm_provider_view(db_session, provider_name)
        if not llm_provider:
            raise ValueError(f"No LLM provider found with name: {provider_name}")
        # Override > persona override > provider default, first truthy wins.
        model_name = (
            version_override
            or persona.llm_model_version_override
            or llm_provider.default_model_name
        )
    else:
        llm_provider = fetch_default_provider(db_session)
        if not llm_provider:
            raise ValueError("No default LLM provider found")
        model_name = llm_provider.default_model_name

    if not model_name:
        raise ValueError("No model name found")

    max_input_tokens = get_max_input_tokens_from_llm_provider(
        llm_provider=llm_provider, model_name=model_name
    )

    return LLMConfig(
        model_provider=llm_provider.provider,
        model_name=model_name,
        temperature=temperature_override or GEN_AI_TEMPERATURE,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        deployment_name=llm_provider.deployment_name,
        custom_config=llm_provider.custom_config,
        max_input_tokens=max_input_tokens,
    )
def get_llm_for_persona(
persona: Persona | PersonaOverrideConfig | None,
user: User,
@@ -108,6 +62,11 @@ def get_llm_for_persona(
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
"""Get the appropriate LLM for a persona, with the following priority:
1. LLM override (provider + model version)
2. Persona's model configuration override
3. Default LLM
"""
if persona is None:
logger.warning("No persona provided, using default LLM")
return get_default_llm()
@@ -116,8 +75,7 @@ def get_llm_for_persona(
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
provider_name = provider_name_override or persona.llm_model_provider_override
if not provider_name:
if not provider_name_override and not persona.default_model_configuration_id:
return get_default_llm(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
@@ -125,7 +83,11 @@ def get_llm_for_persona(
)
with get_session_with_current_tenant() as db_session:
provider_model = fetch_existing_llm_provider(provider_name, db_session)
# Resolve the provider model
# At least one of the vars is non-None due to the above check, so we should get something
provider_model = _resolve_provider_model(
db_session, provider_name_override, persona.default_model_configuration_id
)
if not provider_model:
raise ValueError("No LLM provider found")
@@ -155,7 +117,14 @@ def get_llm_for_persona(
llm_provider = LLMProviderView.from_model(provider_model)
model = model_version_override or persona.llm_model_version_override
# Resolve the model name
model = model_version_override
if not model and persona.default_model_configuration_id:
model_config = fetch_model_configuration_view(
db_session, persona.default_model_configuration_id
)
model = model_config.name if model_config else None
if not model:
raise ValueError("No model name found")
@@ -176,6 +145,25 @@ def get_llm_for_persona(
)
def _resolve_provider_model(
    db_session: Session,
    provider_name_override: str | None,
    model_config_id_override: int | None,
) -> LLMProvider | None:
    """Resolve the LLM provider model from overrides."""
    # An explicit provider name takes priority over a model-config override.
    if provider_name_override:
        return fetch_existing_llm_provider(provider_name_override, db_session)

    if not model_config_id_override:
        return None

    # Map the model configuration back to its owning provider, if any.
    llm_view = fetch_llm_provider_view_from_model_id(
        db_session, model_config_id_override
    )
    if not llm_view:
        return None
    return fetch_existing_llm_provider(llm_view.name, db_session)
def get_default_llm_with_vision(
timeout: int | None = None,
temperature: float | None = None,
@@ -210,48 +198,40 @@ def get_default_llm_with_vision(
with get_session_with_current_tenant() as db_session:
# Try the default vision provider first
default_provider = fetch_default_vision_provider(db_session)
if default_provider and default_provider.default_vision_model:
if model_supports_image_input(
default_provider.default_vision_model, default_provider.provider
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
default_model = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.VISION
)
if default_model:
provider_view = fetch_llm_provider_view_from_id(
db_session, default_model.provider_id
)
if provider_view:
return create_vision_llm(provider_view, default_model.model_name)
# Fall back to searching all vision models
models = fetch_existing_model_configs_for_flow(
db_session=db_session,
flows=[ModelFlowType.VISION],
)
# Fall back to searching all providers
providers = fetch_existing_llm_providers(db_session)
if not providers:
if not models:
return None
# Check all providers for viable vision models
for provider in providers:
provider_view = LLMProviderView.from_model(provider)
# Check for viable vision models
non_public_vision_llm: LLM | None = None
# First priority: Check if provider has a default_vision_model
if provider.default_vision_model and model_supports_image_input(
provider.default_vision_model, provider.provider
):
return create_vision_llm(provider_view, provider.default_vision_model)
for model in models:
if model.is_visible:
return create_vision_llm(
provider=LLMProviderView.from_model(model.llm_provider),
model=model.name,
)
elif not non_public_vision_llm:
non_public_vision_llm = create_vision_llm(
provider=LLMProviderView.from_model(model.llm_provider),
model=model.name,
)
# If no model-configurations are specified, try default model
if not provider.model_configurations:
# Try default_model_name
if provider.default_model_name and model_supports_image_input(
provider.default_model_name, provider.provider
):
return create_vision_llm(provider_view, provider.default_model_name)
# Otherwise, if model-configurations are specified, check each model
else:
for model_configuration in provider.model_configurations:
if model_supports_image_input(
model_configuration.name, provider.provider
):
return create_vision_llm(provider_view, model_configuration.name)
return None
return non_public_vision_llm
def llm_from_provider(
@@ -298,18 +278,25 @@ def get_default_llm(
long_term_logger: LongTermLogger | None = None,
) -> LLM:
with get_session_with_current_tenant() as db_session:
llm_provider = fetch_default_provider(db_session)
default_model = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
)
if not llm_provider:
raise ValueError("No default LLM provider found")
if not default_model:
raise ValueError("No default LLM provider found")
model_name = llm_provider.default_model_name
if not model_name:
raise ValueError("No default model name found")
llm_provider_view = fetch_llm_provider_view_from_id(
db_session, default_model.provider_id
)
if not llm_provider_view:
raise ValueError(
"No LLM provider found with id {}".format(default_model.provider_id)
)
return llm_from_provider(
model_name=model_name,
llm_provider=llm_provider,
model_name=default_model.model_name,
llm_provider=llm_provider_view,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,

View File

@@ -17,6 +17,7 @@ from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ModelFlowType
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
from onyx.llm.constants import LlmProviderNames
@@ -689,8 +690,8 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
LLMProvider.provider == model_provider,
)
)
if model_config and model_config.supports_image_input is not None:
return model_config.supports_image_input
if model_config:
return ModelFlowType.VISION in model_config.model_flow_types
except Exception as e:
logger.warning(
f"Failed to query database for {model_provider} model {model_name} image support: {e}"

View File

@@ -18,7 +18,7 @@ from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
from onyx.llm.well_known_providers.models import WellKnownLLMProviderDescriptor
from onyx.server.manage.llm.models import ModelConfigurationView
from onyx.server.manage.llm.models import ModelConfiguration
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -212,11 +212,11 @@ def get_vertexai_model_names() -> list[str]:
def model_configurations_for_provider(
provider_name: str, llm_recommendations: LLMRecommendations
) -> list[ModelConfigurationView]:
) -> list[ModelConfiguration]:
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
recommended_visible_models_names = [m.name for m in recommended_visible_models]
return [
ModelConfigurationView(
ModelConfiguration(
name=model_name,
is_visible=model_name in recommended_visible_models_names,
max_input_tokens=get_max_input_tokens(model_name, provider_name),

View File

@@ -3,7 +3,7 @@ from enum import Enum
from pydantic import BaseModel
from pydantic import Field
from onyx.server.manage.llm.models import ModelConfigurationView
from onyx.server.manage.llm.models import ModelConfiguration
class CustomConfigKeyType(str, Enum):
@@ -28,5 +28,5 @@ class WellKnownLLMProviderDescriptor(BaseModel):
name: str
# NOTE: the recommended visible models are encoded in the known_models list
known_models: list[ModelConfigurationView] = Field(default_factory=list)
known_models: list[ModelConfiguration] = Field(default_factory=list)
recommended_default_model: SimpleKnownModel | None = None

View File

@@ -27,8 +27,10 @@ from sqlalchemy.orm import Session as DBSession
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import MessageType
from onyx.db.enums import ModelFlowType
from onyx.db.enums import SandboxStatus
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_model
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -338,15 +340,25 @@ class SessionManager:
)
# Fallback to system default
default_provider = fetch_default_provider(self._db_session)
if not default_provider:
default_model = fetch_default_model(
db_session=self._db_session, flow_type=ModelFlowType.CONVERSATION
)
if not default_model:
raise ValueError("No LLM provider configured")
provider = fetch_existing_llm_provider_by_id(
id=default_model.provider_id,
db_session=self._db_session,
)
if not provider:
raise ValueError("No LLM provider found")
return LLMProviderConfig(
provider=default_provider.provider,
model_name=default_provider.default_model_name,
api_key=default_provider.api_key,
api_base=default_provider.api_base,
provider=provider.provider,
model_name=default_model.model_name,
api_key=provider.api_key,
api_base=provider.api_base,
)
# =========================================================================

View File

@@ -18,6 +18,7 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.llm import fetch_model_configuration_view
from onyx.db.models import User
from onyx.db.persona import create_assistant_label
from onyx.db.persona import create_update_persona
@@ -48,7 +49,7 @@ from onyx.server.features.persona.models import PersonaLabelCreate
from onyx.server.features.persona.models import PersonaLabelResponse
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.manage.llm.api import get_valid_model_names_for_persona
from onyx.server.manage.llm.api import get_valid_models_for_persona
from onyx.server.models import DisplayPriorityRequest
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -479,13 +480,21 @@ def get_persona(
)
# Validate and fix default model if it's no longer valid for this persona's restrictions
if persona.llm_model_version_override:
valid_models = get_valid_model_names_for_persona(persona_id, user, db_session)
if persona.default_model_configuration_id:
valid_models = get_valid_models_for_persona(persona_id, user, db_session)
model_configuration = fetch_model_configuration_view(
db_session=db_session,
model_configuration_id=persona.default_model_configuration_id,
)
# If current default model is not in the valid list, update to first valid or None
if persona.llm_model_version_override not in valid_models:
persona.llm_model_version_override = (
valid_models[0] if valid_models else None
if not (
model_configuration
and model_configuration.id in [m.id for m in valid_models]
):
persona.default_model_configuration_id = (
valid_models[0].id if valid_models else None
)
db_session.commit()

View File

@@ -109,8 +109,7 @@ class PersonaUpsertRequest(BaseModel):
recency_bias: RecencyBiasSetting
llm_filter_extraction: bool
llm_relevance_filter: bool
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
default_model_configuration_id: int | None = None
starter_messages: list[StarterMessage] | None = None
# For Private Personas, who should be able to access these
users: list[UUID] = Field(default_factory=list)
@@ -162,8 +161,7 @@ class MinimalPersonaSnapshot(BaseModel):
# Unique sources from all knowledge (document sets + hierarchy nodes)
# Used to populate source filters in chat
knowledge_sources: list[DocumentSource]
llm_model_version_override: str | None
llm_model_provider_override: str | None
default_model_configuration_id: int | None
uploaded_image_id: str | None
icon_name: str | None
@@ -219,8 +217,7 @@ class MinimalPersonaSnapshot(BaseModel):
hierarchy_node_count=len(persona.hierarchy_nodes),
attached_document_count=len(persona.attached_documents),
knowledge_sources=list(sources),
llm_model_version_override=persona.llm_model_version_override,
llm_model_provider_override=persona.llm_model_provider_override,
default_model_configuration_id=persona.default_model_configuration_id,
uploaded_image_id=persona.uploaded_image_id,
icon_name=persona.icon_name,
is_public=persona.is_public,
@@ -259,8 +256,7 @@ class PersonaSnapshot(BaseModel):
users: list[MinimalUserSnapshot]
groups: list[int]
document_sets: list[DocumentSetSummary]
llm_model_provider_override: str | None
llm_model_version_override: str | None
default_model_configuration_id: int | None
num_chunks: float | None
# Hierarchy nodes attached for scoped search
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
@@ -318,8 +314,7 @@ class PersonaSnapshot(BaseModel):
DocumentSetSummary.from_model(document_set_model)
for document_set_model in persona.document_sets
],
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
default_model_configuration_id=persona.default_model_configuration_id,
num_chunks=persona.num_chunks,
system_prompt=persona.system_prompt,
replace_base_system_prompt=persona.replace_base_system_prompt,
@@ -391,8 +386,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
search_start_date=persona.search_start_date,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
default_model_configuration_id=persona.default_model_configuration_id,
system_prompt=persona.system_prompt,
replace_base_system_prompt=persona.replace_base_system_prompt,
task_prompt=persona.task_prompt,

View File

@@ -141,8 +141,7 @@ def enable_or_disable_kg(
recency_bias=RecencyBiasSetting.NO_DECAY,
document_set_ids=[],
tool_ids=[search_tool.id, kg_tool.id],
llm_model_provider_override=None,
llm_model_version_override=None,
default_model_configuration_id=None,
starter_messages=None,
users=[user.id],
groups=[],

View File

@@ -11,7 +11,7 @@ from onyx.db.image_generation import get_all_image_generation_configs
from onyx.db.image_generation import get_image_generation_config
from onyx.db.image_generation import set_default_image_generation_config
from onyx.db.image_generation import unset_default_image_generation_config
from onyx.db.llm import remove_llm_provider__no_commit
from onyx.db.llm import remove_llm_provider
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import ModelConfiguration
from onyx.db.models import User
@@ -93,7 +93,6 @@ def _build_llm_provider_request(
api_key=source_provider.api_key, # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -132,7 +131,6 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -164,7 +162,6 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,
@@ -457,7 +454,9 @@ def update_config(
existing_config.model_configuration_id = new_model_config_id
# 5. Delete old LLM provider (safe now - nothing references it)
remove_llm_provider__no_commit(db_session, old_llm_provider_id)
remove_llm_provider(
db_session=db_session, provider_id=old_llm_provider_id, commit=False
)
db_session.commit()
db_session.refresh(existing_config)
@@ -491,7 +490,9 @@ def delete_config(
delete_image_generation_config__no_commit(db_session, image_provider_id)
# Clean up the orphaned LLM provider (it was exclusively for image gen)
remove_llm_provider__no_commit(db_session, llm_provider_id)
remove_llm_provider(
db_session=db_session, provider_id=llm_provider_id, commit=False
)
db_session.commit()
except HTTPException:

View File

@@ -19,14 +19,17 @@ from onyx.auth.schemas import UserRole
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_supporting_flows
from onyx.db.llm import fetch_existing_model_configs_for_flow
from onyx.db.llm import fetch_persona_with_groups
from onyx.db.llm import fetch_user_group_ids
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_model_configurations
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import update_default_vision_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.llm import validate_persona_ids_exist
@@ -37,7 +40,6 @@ from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import get_bedrock_token_limit
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.llm.well_known_providers.auto_update_service import (
fetch_llm_recommendations_from_github,
@@ -50,11 +52,13 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationView
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -211,7 +215,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -246,13 +250,15 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session, exclude_image_generation_providers=not include_image_gen
for llm_provider_model in fetch_existing_llm_providers_supporting_flows(
db_session=db_session,
flows=[ModelFlowType.CONVERSATION, ModelFlowType.VISION],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
@@ -269,7 +275,18 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
default_text = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
)
default_vision = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.VISION
)
return LLMProviderResponse[LLMProviderView].from_models(
default_text=default_text,
default_vision=default_vision,
providers=llm_provider_list,
)
@admin_router.put("/provider")
@@ -328,20 +345,6 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -354,6 +357,28 @@ def put_llm_provider(
not existing_provider or not existing_provider.is_auto_mode
)
default_vision_model = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.VISION
)
# Check that we're not disabling vision on the default vision model
if (
default_vision_model
and existing_provider
and default_vision_model.provider_id == existing_provider.id
and any(
m.name == default_vision_model.model_name and not m.supports_image_input
for m in llm_provider_upsert_request.model_configurations
)
):
raise HTTPException(
status_code=409,
detail=(
f"Cannot disable vision for '{default_vision_model.model_name}' because it is the default vision model."
"Select another model as default first."
),
)
try:
result = upsert_llm_provider(
llm_provider_upsert_request=llm_provider_upsert_request,
@@ -393,31 +418,51 @@ def delete_llm_provider(
db_session: Session = Depends(get_session),
) -> None:
try:
default_model = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
)
if default_model and default_model.provider_id == provider_id:
raise HTTPException(
status_code=409,
detail="Cannot delete default provider. Select another provider as default",
)
default_vision_model = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.VISION
)
if default_vision_model and default_vision_model.provider_id == provider_id:
raise HTTPException(
status_code=409,
detail="Cannot delete default vision provider. Select another provider as default",
)
remove_llm_provider(db_session, provider_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
provider_id: int,
_: User = Depends(current_admin_user),
@admin_router.post("/default")
def set_default_model(
default_model_request: DefaultModel,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_text_provider(
provider_id=default_model_request.provider_id,
model=default_model_request.model_name,
db_session=db_session,
)
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
_: User = Depends(current_admin_user),
@admin_router.post("/default-vision")
def set_default_vision_model(
default_vision_request: DefaultModel,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_vision_request.provider_id,
vision_model=default_vision_request.model_name,
db_session=db_session,
)
@@ -446,40 +491,34 @@ def get_vision_capable_providers(
) -> list[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
providers = fetch_existing_llm_providers(db_session)
vision_providers = []
vision_providers = fetch_existing_llm_providers_supporting_flows(
db_session=db_session,
flows=[ModelFlowType.VISION],
)
logger.info("Fetching vision-capable providers")
response: list[VisionProviderResponse] = []
for provider in providers:
vision_models = []
for provider in vision_providers:
provider_view = LLMProviderView.from_model(provider)
_mask_provider_credentials(provider_view)
# Check each model for vision capability
for model_configuration in provider.model_configurations:
if model_supports_image_input(model_configuration.name, provider.provider):
vision_models.append(model_configuration.name)
logger.debug(
f"Vision model found: {provider.provider}/{model_configuration.name}"
)
# Only include providers with at least one vision-capable model
if vision_models:
provider_view = LLMProviderView.from_model(provider)
_mask_provider_credentials(provider_view)
vision_providers.append(
VisionProviderResponse(
**provider_view.model_dump(),
vision_models=vision_models,
)
response.append(
VisionProviderResponse(
**provider_view.model_dump(),
vision_models=[
model.name
for model in provider.model_configurations
if model.supports_image_input
],
)
)
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"
)
logger.info(
f"Vision provider: {provider.provider} with models: {response[-1].vision_models}"
)
logger.info(f"Found {len(vision_providers)} vision-capable providers")
return vision_providers
return response
"""Endpoints for all"""
@@ -489,7 +528,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -502,7 +541,10 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(db_session)
all_providers = fetch_existing_llm_providers_supporting_flows(
db_session=db_session,
flows=[ModelFlowType.CONVERSATION, ModelFlowType.VISION],
)
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN
@@ -526,14 +568,25 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
default_text = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.CONVERSATION
)
default_vision = fetch_default_model(
db_session=db_session, flow_type=ModelFlowType.VISION
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=default_text,
default_vision=default_vision,
)
def get_valid_model_names_for_persona(
def get_valid_models_for_persona(
persona_id: int,
user: User,
db_session: Session,
) -> list[str]:
) -> list[ModelConfigurationView]:
"""Get all valid model names that a user can access for this persona.
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
@@ -545,7 +598,10 @@ def get_valid_model_names_for_persona(
return []
is_admin = user.role == UserRole.ADMIN
all_providers = fetch_existing_llm_providers(db_session)
all_providers = fetch_existing_llm_providers_supporting_flows(
db_session=db_session,
flows=[ModelFlowType.CONVERSATION],
)
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
valid_models = []
@@ -557,7 +613,12 @@ def get_valid_model_names_for_persona(
# Collect all model names from this provider
for model_config in llm_provider_model.model_configurations:
if model_config.is_visible:
valid_models.append(model_config.name)
valid_models.append(
ModelConfigurationView.from_model(
model_configuration_model=model_config,
provider=llm_provider_model.provider,
)
)
return valid_models
@@ -592,7 +653,10 @@ def list_llm_providers_for_persona(
)
is_admin = user.role == UserRole.ADMIN
all_providers = fetch_existing_llm_providers(db_session)
all_providers = fetch_existing_llm_providers_supporting_flows(
db_session=db_session,
flows=[ModelFlowType.CONVERSATION],
)
user_group_ids = set() if is_admin else fetch_user_group_ids(db_session, user)
llm_provider_list: list[LLMProviderDescriptor] = []
@@ -630,32 +694,34 @@ def get_provider_contextual_cost(
- the chunk_context
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
"""
providers = fetch_existing_llm_providers(db_session)
models = fetch_existing_model_configs_for_flow(
db_session=db_session,
flows=[ModelFlowType.CONVERSATION],
)
costs = []
for provider in providers:
for model_configuration in provider.model_configurations:
llm_provider = LLMProviderView.from_model(provider)
llm = get_llm(
provider=provider.provider,
model=model_configuration.name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
max_input_tokens=get_max_input_tokens_from_llm_provider(
llm_provider=llm_provider, model_name=model_configuration.name
),
)
cost = get_llm_contextual_cost(llm)
costs.append(
LLMCost(
provider=provider.name,
model_name=model_configuration.name,
cost=cost,
)
)
for model in models:
llm_provider = LLMProviderView.from_model(model.llm_provider)
llm = get_llm(
provider=model.llm_provider.provider,
model=model.name,
deployment_name=model.llm_provider.deployment_name,
api_key=model.llm_provider.api_key,
api_base=model.llm_provider.api_base,
api_version=model.llm_provider.api_version,
custom_config=model.llm_provider.custom_config,
max_input_tokens=get_max_input_tokens_from_llm_provider(
llm_provider=llm_provider, model_name=model.name
),
)
cost = get_llm_contextual_cost(llm)
costs.append(
LLMCost(
provider=model.llm_provider.provider,
model_name=model.name,
cost=cost,
)
)
return costs

View File

@@ -1,10 +1,15 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.db.enums import ModelFlowType
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import litellm_thinks_model_supports_image_input
from onyx.llm.utils import model_is_reasoning_model
@@ -21,22 +26,23 @@ if TYPE_CHECKING:
)
T = TypeVar("T", bound="LLMProviderDescriptor | LLMProviderView")
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
name: str
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
# if try and use the existing API key
api_key_changed: bool
custom_config_changed: bool
@@ -54,10 +60,6 @@ class LLMProviderDescriptor(BaseModel):
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -75,10 +77,6 @@ class LLMProviderDescriptor(BaseModel):
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=llm_provider_model.default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -92,13 +90,11 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -119,8 +115,6 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -150,10 +144,6 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=llm_provider_model.default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -165,13 +155,15 @@ class LLMProviderView(LLMProvider):
)
class ModelConfigurationUpsertRequest(BaseModel):
class ModelConfiguration(BaseModel):
name: str
is_visible: bool
max_input_tokens: int | None = None
supports_image_input: bool | None = None
display_name: str | None = None # For dynamic providers, from source API
class ModelConfigurationUpsertRequest(ModelConfiguration):
@classmethod
def from_model(
cls, model_configuration_model: "ModelConfigurationModel"
@@ -180,18 +172,15 @@ class ModelConfigurationUpsertRequest(BaseModel):
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=model_configuration_model.supports_image_input,
supports_image_input=ModelFlowType.VISION
in (flow for flow in model_configuration_model.model_flow_types),
display_name=model_configuration_model.display_name,
)
class ModelConfigurationView(BaseModel):
name: str
is_visible: bool
max_input_tokens: int | None = None
supports_image_input: bool
class ModelConfigurationView(ModelConfiguration):
id: int
supports_reasoning: bool = False
display_name: str | None = None
provider_display_name: str | None = None
vendor: str | None = None
version: str | None = None
@@ -201,26 +190,23 @@ class ModelConfigurationView(BaseModel):
def from_model(
cls,
model_configuration_model: "ModelConfigurationModel",
provider_name: str,
provider: str,
) -> "ModelConfigurationView":
# For dynamic providers (OpenRouter, Bedrock, Ollama), use the display_name
# stored in DB from the source API. Skip LiteLLM parsing entirely.
if (
provider_name in DYNAMIC_LLM_PROVIDERS
and model_configuration_model.display_name
):
if provider in DYNAMIC_LLM_PROVIDERS and model_configuration_model.display_name:
# Extract vendor from model name for grouping (e.g., "Anthropic", "OpenAI")
vendor = extract_vendor_from_model_name(
model_configuration_model.name, provider_name
model_configuration_model.name, provider
)
return cls(
id=model_configuration_model.id,
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=model_configuration_model.max_input_tokens,
supports_image_input=(
model_configuration_model.supports_image_input or False
),
supports_image_input=ModelFlowType.VISION
in (flow for flow in model_configuration_model.model_flow_types),
# Infer reasoning support from model name/display name
supports_reasoning=is_reasoning_model(
model_configuration_model.name,
@@ -239,8 +225,8 @@ class ModelConfigurationView(BaseModel):
# Parse the model name to get display information
# Include provider prefix if not already present (enrichments use full keys like "vertex_ai/...")
model_name = model_configuration_model.name
if provider_name and not model_name.startswith(f"{provider_name}/"):
model_name = f"{provider_name}/{model_name}"
if provider and not model_name.startswith(f"{provider}/"):
model_name = f"{provider}/{model_name}"
parsed = parse_litellm_model_name(model_name)
# Include region in display name for Bedrock cross-region models
@@ -251,24 +237,29 @@ class ModelConfigurationView(BaseModel):
)
return cls(
id=model_configuration_model.id,
name=model_configuration_model.name,
is_visible=model_configuration_model.is_visible,
max_input_tokens=(
model_configuration_model.max_input_tokens
or get_max_input_tokens(
model_name=model_configuration_model.name,
model_provider=provider_name,
model_provider=provider,
)
),
supports_image_input=(
val
if (val := model_configuration_model.supports_image_input) is not None
if (
val := ModelFlowType.VISION
in (flow for flow in model_configuration_model.model_flow_types)
)
is not None
else litellm_thinks_model_supports_image_input(
model_configuration_model.name, provider_name
model_configuration_model.name, provider
)
),
supports_reasoning=model_is_reasoning_model(
model_configuration_model.name, provider_name
model_configuration_model.name, provider
),
# Populate display fields from parsed model name
display_name=display_name,
@@ -371,3 +362,27 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
class LLMProviderResponse(BaseModel, Generic[T]):
providers: list[T]
default_text: DefaultModel | None = None
default_vision: DefaultModel | None = None
@classmethod
def from_models(
cls,
providers: list[T],
default_text: DefaultModel | None = None,
default_vision: DefaultModel | None = None,
) -> LLMProviderResponse[T]:
return cls(
default_text=default_text,
default_vision=default_vision,
providers=providers,
)

View File

@@ -21,10 +21,11 @@ from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.credentials import create_initial_public_credential
from onyx.db.document import check_docs_exist
from onyx.db.enums import EmbeddingPrecision
from onyx.db.enums import ModelFlowType
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import fetch_default_model
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_current_search_settings
@@ -286,7 +287,7 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
if GEN_AI_API_KEY and fetch_default_model(db_session, ModelFlowType.TEXT) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
@@ -298,7 +299,6 @@ def setup_postgres(db_session: Session) -> None:
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -310,7 +310,9 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
update_default_text_provider(
provider_id=new_llm_provider.id, model=llm_model, db_session=db_session
)
def update_default_multipass_indexing(db_session: Session) -> None:

View File

@@ -17,7 +17,6 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -43,7 +42,6 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -8,7 +8,7 @@ from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_llm_provider
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
@@ -27,14 +27,15 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
groups=[],
)
provider = upsert_llm_provider(
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_text_provider(
provider_id=provider.id, model="gpt-4o-mini", db_session=db_session
)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
print(f"Note: Could not create LLM provider: {exc}")

View File

@@ -10,9 +10,10 @@ from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import stream_chat_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.enums import ModelFlowType
from onyx.db.llm import fetch_existing_llm_providers_supporting_flows
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_text_provider
from onyx.db.llm import upsert_llm_provider
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
@@ -36,7 +37,10 @@ def test_answer_with_only_anthropic_provider(
assert anthropic_api_key, "ANTHROPIC_API_KEY environment variable must be set"
# Drop any existing providers so that only Anthropic is available.
for provider in fetch_existing_llm_providers(db_session):
for provider in fetch_existing_llm_providers_supporting_flows(
db_session=db_session,
flows=[ModelFlowType.CONVERSATION],
):
remove_llm_provider(db_session, provider.id)
anthropic_model = "claude-haiku-4-5-20251001"
@@ -47,7 +51,6 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -59,7 +62,11 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, db_session)
update_default_text_provider(
provider_id=anthropic_provider.id,
model=anthropic_model,
db_session=db_session,
)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -79,8 +79,7 @@ def test_stream_chat_message_objects_without_web_search(
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
default_model_configuration_id=None,
starter_messages=None,
system_prompt=None,
task_prompt=None,

View File

@@ -5,6 +5,8 @@ from unittest.mock import Mock
from unittest.mock import patch
from uuid import uuid4
from onyx.db.llm import update_default_text_provider
# Set environment variables to disable model server for testing
os.environ["DISABLE_MODEL_SERVER"] = "true"
os.environ["MODEL_SERVER_HOST"] = "disabled"
@@ -418,13 +420,15 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_default_provider=True,
is_public=True,
)
db_session.add(llm_provider)
db_session.commit()
update_default_text_provider(
provider_id=llm_provider.id, model="gpt-4o", db_session=db_session
)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""
for p in patches:

View File

@@ -6,6 +6,7 @@ import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
@@ -25,6 +26,7 @@ class LLMProviderManager:
personas: list[int] | None = None,
is_public: bool | None = None,
set_as_default: bool = True,
model_configurations: list[ModelConfigurationUpsertRequest] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestLLMProvider:
email = "Unknown"
@@ -36,7 +38,6 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -44,7 +45,7 @@ class LLMProviderManager:
is_public=True if is_public is None else is_public,
groups=groups or [],
personas=personas or [],
model_configurations=[],
model_configurations=model_configurations or [],
api_key_changed=True,
)
@@ -138,11 +139,25 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def build_model_configuration_upsert_request(
name: str,
is_visible: bool,
max_input_tokens: int | None,
supports_image_input: bool,
display_name: str | None,
) -> ModelConfigurationUpsertRequest:
return ModelConfigurationUpsertRequest(
name=name,
is_visible=is_visible,
max_input_tokens=max_input_tokens,
supports_image_input=supports_image_input,
display_name=display_name,
)

View File

@@ -28,8 +28,7 @@ class PersonaManager:
datetime_aware: bool = False,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
llm_model_provider_override: str | None = None,
llm_model_version_override: str | None = None,
default_model_configuration_id: int | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
@@ -55,8 +54,7 @@ class PersonaManager:
recency_bias=recency_bias,
document_set_ids=document_set_ids or [],
tool_ids=tool_ids or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
default_model_configuration_id=default_model_configuration_id,
users=[UUID(user) for user in (users or [])],
groups=groups or [],
label_ids=label_ids or [],
@@ -90,8 +88,7 @@ class PersonaManager:
datetime_aware=datetime_aware,
document_set_ids=document_set_ids or [],
tool_ids=tool_ids or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
default_model_configuration_id=default_model_configuration_id,
users=users or [],
groups=groups or [],
label_ids=label_ids or [],
@@ -112,8 +109,7 @@ class PersonaManager:
datetime_aware: bool = False,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
llm_model_provider_override: str | None = None,
llm_model_version_override: str | None = None,
default_model_configuration_id: int | None = None,
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
@@ -137,11 +133,8 @@ class PersonaManager:
recency_bias=recency_bias or persona.recency_bias,
document_set_ids=document_set_ids or persona.document_set_ids,
tool_ids=tool_ids or persona.tool_ids,
llm_model_provider_override=(
llm_model_provider_override or persona.llm_model_provider_override
),
llm_model_version_override=(
llm_model_version_override or persona.llm_model_version_override
default_model_configuration_id=(
default_model_configuration_id or persona.default_model_configuration_id
),
users=[UUID(user) for user in (users or persona.users)],
groups=groups or persona.groups,
@@ -174,11 +167,8 @@ class PersonaManager:
datetime_aware=datetime_aware,
document_set_ids=updated_persona_data["document_sets"],
tool_ids=updated_persona_data["tools"],
llm_model_provider_override=updated_persona_data[
"llm_model_provider_override"
],
llm_model_version_override=updated_persona_data[
"llm_model_version_override"
default_model_configuration_id=updated_persona_data[
"default_model_configuration_id"
],
users=[user["email"] for user in updated_persona_data["users"]],
groups=updated_persona_data["groups"],
@@ -267,25 +257,14 @@ class PersonaManager:
)
)
if (
fetched_persona.llm_model_provider_override
!= persona.llm_model_provider_override
fetched_persona.default_model_configuration_id
!= persona.default_model_configuration_id
):
mismatches.append(
(
"llm_model_provider_override",
persona.llm_model_provider_override,
fetched_persona.llm_model_provider_override,
)
)
if (
fetched_persona.llm_model_version_override
!= persona.llm_model_version_override
):
mismatches.append(
(
"llm_model_version_override",
persona.llm_model_version_override,
fetched_persona.llm_model_version_override,
"default_model_configuration_id",
persona.default_model_configuration_id,
fetched_persona.default_model_configuration_id,
)
)
if fetched_persona.system_prompt != persona.system_prompt:

View File

@@ -111,6 +111,15 @@ class DATestUserGroup(BaseModel):
cc_pair_ids: list[int]
class DATestModelConfiguration(BaseModel):
id: int
name: str
is_visible: bool
max_input_tokens: int | None
supports_image_input: bool | None
display_name: str | None
class DATestLLMProvider(BaseModel):
id: int
name: str
@@ -123,6 +132,7 @@ class DATestLLMProvider(BaseModel):
personas: list[int]
api_base: str | None = None
api_version: str | None = None
model_configurations: list[DATestModelConfiguration] = Field(default_factory=list)
class DATestImageGenerationConfig(BaseModel):
@@ -157,8 +167,7 @@ class DATestPersona(BaseModel):
recency_bias: RecencyBiasSetting
document_set_ids: list[int]
tool_ids: list[int]
llm_model_provider_override: str | None
llm_model_version_override: str | None
default_model_configuration_id: int | None
users: list[str]
groups: list[int]
label_ids: list[int]

View File

@@ -42,7 +42,6 @@ def _create_provider_with_api(
llm_provider_data = {
"name": name,
"provider": provider_type,
"default_model_name": default_model,
"api_key": "test-api-key-for-auto-mode-testing",
"api_base": None,
"api_version": None,

View File

@@ -96,6 +96,8 @@ def assert_response_is_equivalent(
for name in actual_by_name:
actual_config = actual_by_name[name]
expected_config = expected_by_name[name]
assert "id" in actual_config, f"Config {name} is missing id"
del actual_config["id"] # This will not be in the expected config
assert actual_config == expected_config, (
f"Config mismatch for {name}:\n"
f"Actual: {actual_config}\n"

View File

@@ -11,6 +11,7 @@ from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
@@ -48,22 +49,40 @@ def _create_llm_provider(
default_model_name=default_model_name,
deployment_name=None,
is_public=is_public,
# Use None instead of False to avoid unique constraint violation
# The is_default_provider column has unique=True, so only one True and one False allowed
is_default_provider=is_default if is_default else None,
is_default_vision_provider=False,
default_vision_model=None,
)
db_session.add(provider)
db_session.flush()
return provider
def _create_model_configuration(
db_session: Session,
*,
llm_provider_id: int,
name: str,
is_visible: bool,
max_input_tokens: int | None,
supports_image_input: bool,
display_name: str | None,
) -> ModelConfiguration:
model_configuration = ModelConfiguration(
llm_provider_id=llm_provider_id,
name=name,
is_visible=is_visible,
max_input_tokens=max_input_tokens,
supports_image_input=supports_image_input,
display_name=display_name,
)
db_session.add(model_configuration)
db_session.flush()
return model_configuration
def _create_persona(
db_session: Session,
*,
name: str,
provider_name: str,
default_model_configuration_id: int,
) -> Persona:
persona = Persona(
name=name,
@@ -74,8 +93,7 @@ def _create_persona(
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=provider_name,
llm_model_version_override="gpt-4o-mini",
default_model_configuration_id=default_model_configuration_id,
system_prompt="System prompt",
task_prompt="Task prompt",
datetime_aware=True,
@@ -114,6 +132,15 @@ def test_can_user_access_llm_provider_or_logic(
is_public=True,
is_default=True,
)
_create_model_configuration(
db_session,
llm_provider_id=default_provider.id,
name="default-model",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
# Locked provider - is_public=False with no restrictions
locked_provider = _create_llm_provider(
db_session,
@@ -122,6 +149,15 @@ def test_can_user_access_llm_provider_or_logic(
is_public=False,
is_default=False,
)
_create_model_configuration(
db_session,
llm_provider_id=locked_provider.id,
name="locked-model",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
# Restricted provider - has both group AND persona restrictions (AND logic)
restricted_provider = _create_llm_provider(
db_session,
@@ -130,16 +166,34 @@ def test_can_user_access_llm_provider_or_logic(
is_public=False,
is_default=False,
)
restricted_model_1 = _create_model_configuration(
db_session,
llm_provider_id=restricted_provider.id,
name="restricted-model-1",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
restricted_model_2 = _create_model_configuration(
db_session,
llm_provider_id=restricted_provider.id,
name="restricted-model-2",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
allowed_persona = _create_persona(
db_session,
name="allowed-persona",
provider_name=restricted_provider.name,
default_model_configuration_id=restricted_model_1.id,
)
blocked_persona = _create_persona(
db_session,
name="blocked-persona",
provider_name=restricted_provider.name,
default_model_configuration_id=restricted_model_2.id,
)
access_group = UserGroup(name="access-group")
@@ -253,6 +307,15 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
is_public=True,
is_default=True,
)
_create_model_configuration(
db_session,
llm_provider_id=default_provider.id,
name="default-model",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
restricted_provider = _create_llm_provider(
db_session,
name="restricted-provider",
@@ -260,11 +323,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
is_public=False,
is_default=False,
)
restricted_model_configuration = _create_model_configuration(
db_session,
llm_provider_id=restricted_provider.id,
name="restricted-model",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
persona = _create_persona(
db_session,
name="fallback-persona",
provider_name=restricted_provider.name,
default_model_configuration_id=restricted_model_configuration.id,
)
access_group = UserGroup(name="persona-group")
@@ -300,13 +371,13 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
persona=persona,
user=admin_model,
)
assert allowed_llm.config.model_name == restricted_provider.default_model_name
assert allowed_llm.config.model_name == "gpt-4o-mini"
fallback_llm = get_llm_for_persona(
persona=persona,
user=basic_model,
)
assert fallback_llm.config.model_name == default_provider.default_model_name
assert fallback_llm.config.model_name == "gpt-4o"
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -382,9 +453,19 @@ def test_provider_delete_clears_persona_references(reset: None) -> None:
is_public=False,
set_as_default=False,
user_performing_action=admin_user,
model_configurations=[
LLMProviderManager.build_model_configuration_upsert_request(
name="model-1",
is_visible=True,
max_input_tokens=None,
supports_image_input=False,
display_name=None,
)
],
)
persona = PersonaManager.create(
llm_model_provider_override=provider.name,
default_model_configuration_id=provider.model_configurations[0].id,
user_performing_action=admin_user,
)
@@ -394,11 +475,11 @@ def test_provider_delete_clears_persona_references(reset: None) -> None:
user_performing_action=admin_user,
)
# Verify the persona now falls back to default (llm_model_provider_override cleared)
# Verify the persona now falls back to default (default_model_configuration_id cleared)
persona_response = requests.get(
f"{API_SERVER_URL}/persona/{persona.id}",
headers=admin_user.headers,
)
assert persona_response.status_code == 200
updated_persona = persona_response.json()
assert updated_persona["llm_model_provider_override"] is None
assert updated_persona["default_model_configuration_id"] is None

View File

@@ -47,7 +47,8 @@ class TestSyncModelConfigurations:
assert result == 2 # Two new models
assert mock_session.execute.call_count == 2
mock_session.commit.assert_called_once()
# Single commit at end for atomicity (flow mappings use flush)
assert mock_session.commit.call_count == 1
def test_skips_existing_models(self) -> None:
"""Test that existing models are not overwritten."""

View File

@@ -109,7 +109,9 @@ class TestGetOllamaAvailableModels:
# Verify DB operations were called
assert mock_session.execute.call_count == 3 # 3 models inserted
mock_session.commit.assert_called_once()
assert (
mock_session.commit.call_count == 6
) # 3 for model config + 3 for flow mapping
def test_no_sync_when_provider_name_not_specified(
self, mock_ollama_tags_response: dict, mock_ollama_show_response: dict
@@ -251,7 +253,7 @@ class TestGetOpenRouterAvailableModels:
# Verify DB operations were called
assert mock_session.execute.call_count == 3 # 3 models inserted
mock_session.commit.assert_called_once()
mock_session.commit.call_count == 6 # 3 for model config + 3 for flow mapping
def test_preserves_existing_models_on_sync(
self, mock_openrouter_response: dict

View File

@@ -43,8 +43,7 @@ export interface MinimalPersonaSnapshot {
// Unique sources from all knowledge (document sets + hierarchy nodes)
// Used to populate source filters in chat
knowledge_sources?: ValidSources[];
llm_model_version_override?: string;
llm_model_provider_override?: string;
default_model_configuration_id?: number;
uploaded_image_id?: string;
icon_name?: string;

View File

@@ -16,8 +16,7 @@ interface PersonaUpsertRequest {
recency_bias: string;
llm_filter_extraction: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
default_model_configuration_id: number | null;
starter_messages: StarterMessage[] | null;
users?: string[];
groups: number[];
@@ -48,8 +47,7 @@ export interface PersonaUpsertParameters {
num_chunks: number | null;
is_public: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
default_model_configuration_id: number | null;
starter_messages: StarterMessage[] | null;
users?: string[];
groups: number[];
@@ -88,8 +86,7 @@ function buildPersonaUpsertRequest({
uploaded_image_id,
is_default_persona,
llm_relevance_filter,
llm_model_provider_override,
llm_model_version_override,
default_model_configuration_id,
starter_messages,
label_ids,
replace_base_system_prompt,
@@ -114,8 +111,7 @@ function buildPersonaUpsertRequest({
recency_bias: "base_decay",
llm_filter_extraction: false,
llm_relevance_filter: llm_relevance_filter ?? null,
llm_model_provider_override: llm_model_provider_override ?? null,
llm_model_version_override: llm_model_version_override ?? null,
default_model_configuration_id: default_model_configuration_id ?? null,
starter_messages: starter_messages ?? null,
display_priority: null,
label_ids: label_ids ?? null,

View File

@@ -144,6 +144,7 @@ export function ModelConfigurationField({
<CreateButton
onClick={() => {
arrayHelpers.push({
id: -1,
name: "",
is_visible: true,
// Use null so Yup.number().nullable() accepts empty inputs

View File

@@ -214,6 +214,7 @@ describe("Custom LLM Provider Configuration Workflow", () => {
default_model_name: "claude-3-opus",
model_configurations: [
{
id: 1,
name: "claude-3-opus",
is_visible: true,
max_input_tokens: null,

View File

@@ -7,7 +7,7 @@ import {
Formik,
ErrorMessage,
} from "formik";
import { LLMProviderFormProps, LLMProviderView } from "../interfaces";
import { LLMProviderFormProps } from "../interfaces";
import * as Yup from "yup";
import { ProviderFormEntrypointWrapper } from "./components/FormWrapper";
import { DisplayNameField } from "./components/DisplayNameField";
@@ -69,7 +69,7 @@ export function CustomForm({
...modelConfiguration,
max_input_tokens: modelConfiguration.max_input_tokens ?? null,
})
) ?? [{ name: "", is_visible: true, max_input_tokens: null }],
) ?? [{ id: -1, name: "", is_visible: true, max_input_tokens: null }],
custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config)
: [],

View File

@@ -14,8 +14,11 @@ export interface ModelConfiguration {
is_visible: boolean;
max_input_tokens: number | null;
supports_image_input: boolean | null;
supports_reasoning?: boolean;
display_name?: string;
}
export interface ModelConfigurationView extends ModelConfiguration {
id: number;
supports_reasoning?: boolean;
provider_display_name?: string;
vendor?: string;
version?: string;
@@ -55,7 +58,7 @@ export interface LLMProvider {
deployment_name: string | null;
default_vision_model: string | null;
is_default_vision_provider: boolean | null;
model_configurations: ModelConfiguration[];
model_configurations: ModelConfigurationView[];
}
export interface LLMProviderView extends LLMProvider {
@@ -78,7 +81,7 @@ export interface LLMProviderDescriptor {
is_public?: boolean;
groups?: number[];
personas?: number[];
model_configurations: ModelConfiguration[];
model_configurations: ModelConfigurationView[];
}
export interface OllamaModelResponse {

View File

@@ -618,9 +618,11 @@ export function useLlmManager(
setCurrentLlm(
getValidLlmDescriptor(currentChatSession.current_alternate_model)
);
} else if (liveAssistant?.llm_model_version_override) {
} else if (liveAssistant?.default_model_configuration_id) {
setCurrentLlm(
getValidLlmDescriptor(liveAssistant.llm_model_version_override)
getValidLlmDescriptorFromModelId(
liveAssistant.default_model_configuration_id
)
);
} else if (userHasManuallyOverriddenLLM) {
// if the user has an override and there's nothing special about the
@@ -640,6 +642,17 @@ export function useLlmManager(
setChatSession(currentChatSession || null);
};
function getValidLlmDescriptorFromModelId(
modelId: number | null | undefined
): LlmDescriptor {
// Get the name of the model
const model = llmProviders
?.find((p) => p.model_configurations.find((m) => m.id === modelId))
?.model_configurations.find((m) => m.id === modelId);
return getValidLlmDescriptor(model?.name);
}
function getValidLlmDescriptor(
modelName: string | null | undefined
): LlmDescriptor {

View File

@@ -18,15 +18,20 @@ export function getFinalLLM(
let model = defaultProvider?.default_model_name || "";
if (persona) {
// Map "provider override" to actual LLLMProvider
if (persona.llm_model_provider_override) {
// Map "model override" to actual model
if (persona.default_model_configuration_id) {
const underlyingProvider = llmProviders.find(
(item: LLMProviderDescriptor) =>
item.name === persona.llm_model_provider_override
item.model_configurations.find(
(m) => m.id === persona.default_model_configuration_id
)
);
const underlyingModel = underlyingProvider?.model_configurations.find(
(m) => m.id === persona.default_model_configuration_id
);
provider = underlyingProvider?.provider || provider;
model = underlyingModel?.name || model;
}
model = persona.llm_model_version_override || model;
}
if (currentLlm) {
@@ -41,26 +46,27 @@ export function getLLMProviderOverrideForPersona(
liveAssistant: MinimalPersonaSnapshot,
llmProviders: LLMProviderDescriptor[]
): LlmDescriptor | null {
const overrideProvider = liveAssistant.llm_model_provider_override;
const overrideModel = liveAssistant.llm_model_version_override;
const defaultModelConfigurationId =
liveAssistant.default_model_configuration_id;
if (!overrideModel) {
if (!defaultModelConfigurationId) {
return null;
}
const matchingProvider = llmProviders.find(
(provider) =>
(overrideProvider ? provider.name === overrideProvider : true) &&
provider.model_configurations
.map((modelConfiguration) => modelConfiguration.name)
.includes(overrideModel)
const matchingProvider = llmProviders.find((provider) =>
provider.model_configurations.find(
(m) => m.id === defaultModelConfigurationId
)
);
const underlyingModel = matchingProvider?.model_configurations.find(
(m) => m.id === defaultModelConfigurationId
);
if (matchingProvider) {
if (matchingProvider && underlyingModel) {
return {
name: matchingProvider.name,
provider: matchingProvider.provider,
modelName: overrideModel,
modelName: underlyingModel.name,
};
}

View File

@@ -451,37 +451,59 @@ export default function AgentEditorPage({
const deleteAgentModal = useCreateModal();
// LLM Model Selection
const getCurrentLlm = useCallback(
(values: any, llmProviders: any) =>
values.llm_model_version_override && values.llm_model_provider_override
? (() => {
const provider = llmProviders?.find(
(p: any) => p.name === values.llm_model_provider_override
);
return structureValue(
values.llm_model_provider_override,
provider?.provider || "",
values.llm_model_version_override
);
})()
: null,
const findProviderAndModel = useCallback(
(
providers: any[] | undefined,
modelPredicate: (m: any) => boolean,
providerPredicate?: (p: any) => boolean
) => {
const candidateProviders = providerPredicate
? providers?.filter(providerPredicate)
: providers;
const provider = candidateProviders?.find(
(p: any) => p.model_configurations?.find(modelPredicate)
);
const model = provider?.model_configurations?.find(modelPredicate);
return { provider, model };
},
[]
);
const getCurrentLlm = useCallback(
(values: any, llmProviders: any) =>
values.default_model_configuration_id
? (() => {
const { provider, model } = findProviderAndModel(
llmProviders,
(m: any) => m.id === values.default_model_configuration_id
);
return structureValue(
provider?.name || "",
provider?.provider || "",
model?.name || ""
);
})()
: null,
[findProviderAndModel]
);
const onLlmSelect = useCallback(
(selected: string | null, setFieldValue: any) => {
(selected: string | null, setFieldValue: any, llmProviders: any) => {
if (selected === null) {
setFieldValue("llm_model_version_override", null);
setFieldValue("llm_model_provider_override", null);
setFieldValue("default_model_configuration_id", null);
} else {
const { modelName, name } = parseLlmDescriptor(selected);
if (modelName && name) {
setFieldValue("llm_model_version_override", modelName);
setFieldValue("llm_model_provider_override", name);
const { model } = findProviderAndModel(
llmProviders,
(m: any) => m.name === modelName,
(p: any) => p.name === name
);
setFieldValue("default_model_configuration_id", model?.id || null);
}
}
},
[]
[findProviderAndModel]
);
// Hooks for Knowledge section
@@ -576,10 +598,8 @@ export default function AgentEditorPage({
selected_sources: [] as ValidSources[],
// Advanced
llm_model_provider_override:
existingAgent?.llm_model_provider_override ?? null,
llm_model_version_override:
existingAgent?.llm_model_version_override ?? null,
default_model_configuration_id:
existingAgent?.default_model_configuration_id ?? null,
knowledge_cutoff_date: existingAgent?.search_start_date
? new Date(existingAgent.search_start_date)
: null,
@@ -701,8 +721,7 @@ export default function AgentEditorPage({
),
// Advanced
llm_model_provider_override: Yup.string().nullable().optional(),
llm_model_version_override: Yup.string().nullable().optional(),
default_model_configuration_id: Yup.number().nullable().optional(),
knowledge_cutoff_date: Yup.date().nullable().optional(),
replace_base_system_prompt: Yup.boolean(),
reminders: Yup.string().optional(),
@@ -805,8 +824,8 @@ export default function AgentEditorPage({
// recency_bias: ...,
// llm_filter_extraction: ...,
llm_relevance_filter: false,
llm_model_provider_override: values.llm_model_provider_override || null,
llm_model_version_override: values.llm_model_version_override || null,
default_model_configuration_id:
values.default_model_configuration_id || null,
starter_messages: finalStarterMessages,
users: values.shared_user_ids,
groups: values.shared_group_ids,
@@ -1407,7 +1426,11 @@ export default function AgentEditorPage({
llmProviders
)}
onSelect={(selected) =>
onLlmSelect(selected, setFieldValue)
onLlmSelect(
selected,
setFieldValue,
llmProviders
)
}
/>
</InputLayouts.Horizontal>